matcher.hpp 12.4 KB
Newer Older
Paul's avatar
Paul committed
1
2
#ifndef MIGRAPHX_GUARD_RTGLIB_MATCHER_HPP
#define MIGRAPHX_GUARD_RTGLIB_MATCHER_HPP
Paul's avatar
Paul committed
3

Paul's avatar
Paul committed
4
5
6
7
8
9
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/config.hpp>
Paul's avatar
Paul committed
10
#include <unordered_map>
Paul's avatar
Paul committed
11
#include <unordered_set>
Paul's avatar
Paul committed
12

Paul's avatar
Paul committed
13
namespace migraphx {
Paul's avatar
Paul committed
14
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
15

Paul's avatar
Paul committed
16
namespace match {
Paul's avatar
Paul committed
17

Paul's avatar
Paul committed
18
19
struct matcher_context
{
Paul's avatar
Paul committed
20
    matcher_context(instruction_ref i) : last(i) {}
Paul's avatar
Paul committed
21
    std::unordered_map<std::string, instruction_ref> instructions;
Paul's avatar
Paul committed
22
23
    instruction_ref not_found() const { return last; }

Paul's avatar
Paul committed
24
    template <class M>
Paul's avatar
Paul committed
25
26
27
28
29
    bool matched(M m, instruction_ref ins)
    {
        return m.match(*this, ins) != this->not_found();
    }

Paul's avatar
Paul committed
30
    private:
Paul's avatar
Paul committed
31
    instruction_ref last;
Paul's avatar
Paul committed
32
33
};

Paul's avatar
Paul committed
34
/// Convert a predicate function into a matcher
Paul's avatar
Paul committed
35
template <class P>
Paul's avatar
Paul committed
36
37
38
39
40
41
42
43
44
45
46
47
48
struct predicate_matcher
{
    P p;

    instruction_ref match(matcher_context& ctx, instruction_ref ins) const
    {
        assert(ins != ctx.not_found());
        if(p(ins))
            return ins;
        return ctx.not_found();
    }
};

Paul's avatar
Paul committed
49
/// Convert a function into a matcher
Paul's avatar
Paul committed
50
template <class F>
Paul's avatar
Paul committed
51
52
53
54
55
56
57
58
59
60
61
struct function_matcher
{
    F f;

    instruction_ref match(matcher_context& ctx, instruction_ref ins) const
    {
        assert(ins != ctx.not_found());
        return f(ctx, ins);
    }
};

Paul's avatar
Paul committed
62
/// Convert a function into a matcher
Paul's avatar
Paul committed
63
template <class F>
Paul's avatar
Paul committed
64
65
66
67
68
function_matcher<F> make_function_matcher(F f)
{
    return {f};
}

Paul's avatar
Paul committed
69
/// Converts a matcher to bind the instruction to name
Paul's avatar
Paul committed
70
template <class M>
Paul's avatar
Paul committed
71
72
auto bind_match(M m, std::string name)
{
Paul's avatar
Paul committed
73
74
75
76
77
78
79
    return make_function_matcher(
        [ =, name = std::move(name) ](matcher_context & ctx, instruction_ref ins) {
            auto result = m.match(ctx, ins);
            if(result != ctx.not_found())
                ctx.instructions.emplace(name, ins);
            return result;
        });
Paul's avatar
Paul committed
80
81
}

Paul's avatar
Paul committed
82
/// Convert a matcher to a bindable matcher
Paul's avatar
Paul committed
83
template <class M>
Paul's avatar
Paul committed
84
85
86
87
struct bindable_matcher
{
    M m;

Paul's avatar
Paul committed
88
    auto bind(std::string name) const { return bind_match(m, std::move(name)); }
Paul's avatar
Paul committed
89
90
91
92
93
94
95

    instruction_ref match(matcher_context& ctx, instruction_ref ins) const
    {
        return m.match(ctx, ins);
    }
};

Paul's avatar
Paul committed
96
/// Create a bindable matcher
Paul's avatar
Paul committed
97
template <class M>
Paul's avatar
Paul committed
98
99
100
101
102
bindable_matcher<M> make_bindable_matcher(M m)
{
    return {m};
}

Paul's avatar
Paul committed
103
/// Create a bindable matcher from a function
Paul's avatar
Paul committed
104
template <class F>
Paul's avatar
Paul committed
105
106
107
108
109
bindable_matcher<function_matcher<F>> make_bf_matcher(F f)
{
    return {{f}};
}

Paul's avatar
Paul committed
110
/// Create a bindable matcher from a predicate function
Paul's avatar
Paul committed
111
template <class F>
Paul's avatar
Paul committed
112
113
114
115
116
117
118
119
120
bindable_matcher<predicate_matcher<F>> make_bp_matcher(F f)
{
    return {{f}};
}

// Alias for a braced list of booleans.
// NOTE(review): not referenced in the visible part of this header; users may live elsewhere.
using bool_list = std::initializer_list<bool>;

struct id_matcher
{
Paul's avatar
Paul committed
121
    instruction_ref match(matcher_context&, instruction_ref ins) const { return ins; }
Paul's avatar
Paul committed
122
123
};

Paul's avatar
Paul committed
124
/// The basic matcher provides the all_of composability of the matcher
Paul's avatar
Paul committed
125
template <class M>
Paul's avatar
Paul committed
126
127
128
129
struct basic_matcher
{
    M m;

Paul's avatar
Paul committed
130
    template <class... Ts>
Paul's avatar
Paul committed
131
132
133
134
135
136
    auto operator()(Ts... ms) const
    {
        // Copy m because we cant capture `this` by value
        auto mm = m;
        return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
            auto result = mm.match(ctx, ins);
Paul's avatar
Paul committed
137
            if(result != ctx.not_found())
Paul's avatar
Paul committed
138
139
140
141
142
143
144
145
146
147
148
            {
                bool matches = fold([&](auto x, auto y) {
                    return x and y.match(ctx, result) != ctx.not_found();
                })(true, ms...);
                if(matches)
                    return result;
            }
            return ctx.not_found();
        });
    }

Paul's avatar
Paul committed
149
    auto bind(std::string name) const { return bind_match(m, std::move(name)); }
Paul's avatar
Paul committed
150
151
152
153
154
155
156

    instruction_ref match(matcher_context& ctx, instruction_ref ins) const
    {
        return m.match(ctx, ins);
    }
};

Paul's avatar
Paul committed
157
/// Create a basic matcher from a matcher
Paul's avatar
Paul committed
158
template <class M>
Paul's avatar
Paul committed
159
160
161
162
163
basic_matcher<M> make_basic_matcher(M m)
{
    return {m};
}

Paul's avatar
Paul committed
164
/// Create a basic matcher from a function
Paul's avatar
Paul committed
165
template <class F>
Paul's avatar
Paul committed
166
167
168
169
170
basic_matcher<function_matcher<F>> make_basic_fun_matcher(F f)
{
    return {{f}};
}

Paul's avatar
Paul committed
171
/// Create a basic matcher from a predicate function
Paul's avatar
Paul committed
172
template <class P>
Paul's avatar
Paul committed
173
174
175
176
177
basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p)
{
    return {{p}};
}

Paul's avatar
Paul committed
178
/// This macro takes care of the boilerplate for defining a matcher
///
/// It declares a struct `name##_m` with a `match(__VA_ARGS__)` member and a
/// constant `name` that wraps it in a basic_matcher. It then opens the
/// out-of-line definition of `name##_m::match`, so the macro invocation must be
/// followed by the function body.
#define MIGRAPHX_BASIC_MATCHER(name, ...)                               \
    struct name##_m                                                     \
    {                                                                   \
        instruction_ref match(__VA_ARGS__) const;                       \
    };                                                                  \
    constexpr auto name = migraphx::match::basic_matcher<name##_m>{{}}; \
    inline instruction_ref name##_m::match(__VA_ARGS__) const

Paul's avatar
Paul committed
187
/// This macro takes care of the boilerplate for defining a predicate matcher
///
/// It declares a predicate struct `name##_m` whose call operator takes
/// `__VA_ARGS__`, and a constant `name` that wraps it in a basic_matcher over a
/// predicate_matcher. It then opens the out-of-line definition of the call
/// operator, so the macro invocation must be followed by the function body.
#define MIGRAPHX_PRED_MATCHER(name, ...)                                                  \
    struct name##_m                                                                       \
    {                                                                                     \
        bool operator()(__VA_ARGS__) const;                                               \
    };                                                                                    \
    constexpr auto name =                                                                 \
        migraphx::match::basic_matcher<migraphx::match::predicate_matcher<name##_m>>{{}}; \
    inline bool name##_m::operator()(__VA_ARGS__) const
Paul's avatar
Paul committed
196
197
198
199
200
201
202

/// Outcome of matching a single instruction (produced by match_instruction).
struct matcher_result
{
    /// Instructions bound by name while matching.
    std::unordered_map<std::string, instruction_ref> instructions;
    /// The matched instruction, or the context's not_found() sentinel
    /// (the program's end iterator when produced by match_instruction).
    instruction_ref result;
};

Paul's avatar
Paul committed
203
/// Match a single instruction
Paul's avatar
Paul committed
204
template <class M>
Paul's avatar
Paul committed
205
206
207
208
209
matcher_result match_instruction(program& p, instruction_ref ins, M&& m)
{
    assert(ins != p.end());
    matcher_result result;
    matcher_context ctx{p.end()};
Paul's avatar
Paul committed
210
    result.result       = m.match(ctx, ins);
Paul's avatar
Paul committed
211
    result.instructions = ctx.instructions;
Paul's avatar
Paul committed
212
    return result;
Paul's avatar
Paul committed
213
214
}

Paul's avatar
Paul committed
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
/// Find matches for an instruction in the program
///
/// Tries each finder in `ms` in order. Each finder must provide `matcher()` and
/// `apply(p, r)`. Only the first finder whose matcher matches `ins` is applied;
/// the remaining finders are skipped.
template <class... Ms>
void find_matches(program& p, instruction_ref ins, Ms&&... ms)
{
    bool applied = false;
    each_args(
        [&](auto&& finder) {
            if(applied)
                return;
            auto r = match_instruction(p, ins, finder.matcher());
            if(r.result != p.end())
            {
                finder.apply(p, r);
                applied = true;
            }
        },
        ms...);
}

Paul's avatar
Paul committed
233
/// Find matches in a program
Paul's avatar
Paul committed
234
template <class... Ms>
Paul's avatar
Paul committed
235
236
void find_matches(program& p, Ms&&... ms)
{
Paul's avatar
Paul committed
237
    for(auto ins : iterator_for(p))
Paul's avatar
Paul committed
238
    {
Paul's avatar
Paul committed
239
        find_matches(p, ins, ms...);
Paul's avatar
Paul committed
240
241
242
    }
}

243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
/// Short-circuiting logical AND over two thunks: `g` is only invoked when
/// `f()` is true.
struct lazy_and
{
    template <class F, class G>
    bool operator()(F f, G g) const
    {
        if(f())
            return static_cast<bool>(g());
        return false;
    }
};

/// Short-circuiting logical OR over two thunks: `g` is only invoked when
/// `f()` is false.
struct lazy_or
{
    template <class F, class G>
    bool operator()(F f, G g) const
    {
        if(f())
            return true;
        return static_cast<bool>(g());
    }
};

Paul's avatar
Paul committed
261
template <class Op, bool Start, bool Matches>
Paul's avatar
Paul committed
262
struct folder
Paul's avatar
Paul committed
263
{
Paul's avatar
Paul committed
264
265
266
267
268
    template <class... Ts>
    auto operator()(Ts... ms) const
    {
        return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
            Op op;
269
270
271
            auto matched = [&](auto m) {
                return [&]{ return ctx.matched(m, ins); };
            };
Paul's avatar
Paul committed
272
            bool matches = fold([&](auto x, auto y) {
273
                return op(always(x), matched(y));
Paul's avatar
Paul committed
274
275
276
277
278
279
            })(Start, ms...);
            if(matches == Matches)
                return ins;
            return ctx.not_found();
        });
    }
Paul's avatar
Paul committed
280

Paul's avatar
Paul committed
281
    template <class Selector>
Paul's avatar
Paul committed
282
283
284
285
286
287
288
    auto operator[](Selector select) const
    {
        return [=](auto... ms) {
            return make_bf_matcher([=](matcher_context& ctx, instruction_ref start) {
                Op op;
                bool matches = Start;
                select(start, [&](auto ins) {
289
290
291
292
293
294
295
296
297
                    auto matched = [&](auto m) {
                        return [&]{ return ctx.matched(m, ins); };
                    };
                    auto fold_match = [&] {
                        return fold([&](auto x, auto y) {
                                     return op(always(x), matched(y));
                                 })(Start, ms...);
                    };
                    matches = op(always(matches), fold_match);
Paul's avatar
Paul committed
298
299
300
301
302
303
304
305
306
                });
                if(matches == Matches)
                    return start;
                return ctx.not_found();
            });
        };
    }
};

307
308
309
/// Matches when every sub-matcher matches (also usable as `all_of[selector](...)`).
constexpr auto all_of = folder<lazy_and, true, true>{};
/// Matches when at least one sub-matcher matches.
constexpr auto any_of = folder<lazy_or, false, true>{};
/// Matches when no sub-matcher matches.
constexpr auto none_of = folder<lazy_or, false, false>{};
Paul's avatar
Paul committed
310
311

/// Selector for the folder `[]` operator (e.g. `all_of[inputs()](...)`):
/// calls `f` on each input of `ins`.
inline auto inputs()
{
    return [](auto ins, auto f) {
        auto&& operands = ins->inputs();
        for(auto&& operand : operands)
            f(operand);
    };
}

Paul's avatar
Paul committed
319
/// Selector for the folder `[]` operator (e.g. `any_of[outputs()](...)`):
/// calls `f` on each output of `ins`.
inline auto outputs()
{
    return [](auto ins, auto f) {
        auto&& users = ins->outputs();
        for(auto&& user : users)
            f(user);
    };
}

Paul's avatar
Paul committed
327
328
329
330
// Trivial predicates: `any` accepts every instruction, `none` rejects all.
MIGRAPHX_PRED_MATCHER(any, instruction_ref)
{
    return true;
}
MIGRAPHX_PRED_MATCHER(none, instruction_ref)
{
    return false;
}

// Predicates on the shape of the instruction itself.
MIGRAPHX_PRED_MATCHER(standard_shape, instruction_ref ins)
{
    const auto& s = ins->get_shape();
    return s.standard();
}
MIGRAPHX_PRED_MATCHER(broadcast_shape, instruction_ref ins)
{
    const auto& s = ins->get_shape();
    return s.broadcasted();
}

MIGRAPHX_PRED_MATCHER(transpose_shape, instruction_ref ins)
{
    const auto& s = ins->get_shape();
    return s.transposed();
}

Paul's avatar
Paul committed
340
// Accepts instructions whose inputs all share one shape; an instruction with
// no inputs is not accepted.
MIGRAPHX_PRED_MATCHER(same_input_shapes, instruction_ref ins)
{
    const auto& operands = ins->inputs();
    if(operands.empty())
        return false;
    const auto reference_shape = operands.front()->get_shape();
    return std::none_of(operands.begin(), operands.end(), [&](auto operand) {
        return not(operand->get_shape() == reference_shape);
    });
}

Paul's avatar
Paul committed
349
// Yields the sole output of the instruction; fails when it has zero or
// several outputs.
MIGRAPHX_BASIC_MATCHER(output, matcher_context& ctx, instruction_ref ins)
{
    const auto& users = ins->outputs();
    if(users.size() != 1)
        return ctx.not_found();
    return users.front();
}

Paul's avatar
Paul committed
356
// Accepts an instruction with exactly one output. It also accepts an
// instruction with no outputs when std::next(ins) is the not_found() sentinel,
// i.e. the last instruction when the sentinel is p.end() (as in
// match_instruction).
MIGRAPHX_BASIC_MATCHER(used_once, matcher_context& ctx, instruction_ref ins)
{
    const auto& users        = ins->outputs();
    const bool single_user   = users.size() == 1;
    const bool unused_at_end = users.empty() and std::next(ins) == ctx.not_found();
    if(single_user or unused_at_end)
        return ins;
    return ctx.not_found();
}

Paul's avatar
Paul committed
365
template <class... Ms>
Paul's avatar
Paul committed
366
367
368
369
370
371
372
373
auto skip_output(Ms... ms)
{
    auto m = any_of(ms...);
    return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref start) {
        return fix<instruction_ref>([&](auto self, auto ins) {
            if(ins->outputs().size() == 1)
            {
                auto next = ins->outputs().front();
Paul's avatar
Paul committed
374
                if(ctx.matched(m, next))
Paul's avatar
Paul committed
375
376
                {
                    auto skipped_next = self(next);
Paul's avatar
Paul committed
377
                    if(skipped_next != ctx.not_found())
Paul's avatar
Paul committed
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
                        return skipped_next;
                }
                return next;
            }
            return ctx.not_found();
        })(start);
    });
}

/// Match instructions whose operator name equals `s`.
inline auto name(std::string s)
{
    return make_basic_pred_matcher(
        [ =, s = std::move(s) ](instruction_ref ins) { return s == ins->name(); });
}

inline auto name(std::unordered_set<std::string> names)
Paul's avatar
Paul committed
394
{
Paul's avatar
Paul committed
395
396
397
    return make_basic_pred_matcher([ =, names = std::move(names) ](instruction_ref ins) {
        return names.count(ins->name()) > 0;
    });
Paul's avatar
Paul committed
398
399
}

400
401
402
403
404
/// Match instructions that have exactly `n` inputs.
inline auto nargs(std::size_t n)
{
    return make_basic_pred_matcher([=](instruction_ref ins) {
        const std::size_t count = ins->inputs().size();
        return count == n;
    });
}

Paul's avatar
Paul committed
405
406
407
408
409
410
411
412
413
414
/// Select the `i`-th input of the instruction. Yields ctx.not_found() when the
/// instruction has `i` or fewer inputs.
inline auto arg(std::size_t i)
{
    return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) {
        if(i >= ins->inputs().size())
            return ctx.not_found();
        return ins->inputs()[i];
    });
}

// Workaround for bugs in clang: carries the input indices used by args_impl
// as a plain non-type parameter pack.
template <std::size_t...>
struct args_impl_ints {};
Paul's avatar
Paul committed
419

Paul's avatar
Paul committed
420
template <std::size_t... Ns, class... Ms>
Paul's avatar
Paul committed
421
422
auto args_impl(args_impl_ints<Ns...>, Ms... ms)
{
Paul's avatar
Paul committed
423
    return match::all_of(nargs(sizeof...(Ns)), arg(Ns)(ms)...);
Paul's avatar
Paul committed
424
425
}

Paul's avatar
Paul committed
426
template <class... Ms>
Paul's avatar
Paul committed
427
428
429
430
431
auto args(Ms... ms)
{
    return sequence_c<sizeof...(Ms)>([=](auto... is) {
        // It needs to be written as `decltype(is)::value` for gcc 5
        return args_impl(args_impl_ints<decltype(is)::value...>{}, ms...);
Paul's avatar
Paul committed
432
    });
Paul's avatar
Paul committed
433
434
}

Paul's avatar
Paul committed
435
inline auto either_arg(std::size_t i, std::size_t j)
Paul's avatar
Paul committed
436
437
{
    return [=](auto m1, auto m2) {
Paul's avatar
Paul committed
438
439
        return match::any_of(match::all_of(arg(i)(m1), arg(j)(m2)),
                             match::all_of(arg(j)(m1), arg(i)(m2)));
Paul's avatar
Paul committed
440
441
442
    };
}

Paul's avatar
Paul committed
443
template <class M>
Paul's avatar
Paul committed
444
445
446
447
auto same_shape(M m)
{
    return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) {
        auto i = m.match(ctx, ins);
Paul's avatar
Paul committed
448
        if(i != ctx.not_found() and i->get_shape() == ins->get_shape())
Paul's avatar
Paul committed
449
450
451
452
453
            return ins;
        return ctx.not_found();
    });
}

Paul's avatar
Paul committed
454
template <class... Ms>
Paul's avatar
Paul committed
455
456
457
458
459
auto same_shape(Ms... ms)
{
    return all_of(same_shape(ms)...);
}

Paul's avatar
Paul committed
460
} // namespace match
Paul's avatar
Paul committed
461
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
462
} // namespace migraphx
Paul's avatar
Paul committed
463
464

#endif