matcher.hpp 19 KB
Newer Older
Paul's avatar
Paul committed
1
2
#ifndef MIGRAPHX_GUARD_RTGLIB_MATCHER_HPP
#define MIGRAPHX_GUARD_RTGLIB_MATCHER_HPP
Paul's avatar
Paul committed
3

Paul's avatar
Paul committed
4
5
6
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
7
#include <migraphx/module.hpp>
Paul's avatar
Paul committed
8
9
#include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp>
10
#include <migraphx/type_name.hpp>
Paul's avatar
Paul committed
11
#include <migraphx/config.hpp>
Paul's avatar
Paul committed
12
#include <unordered_map>
Paul's avatar
Paul committed
13
#include <unordered_set>
Paul's avatar
Paul committed
14

Paul's avatar
Paul committed
15
namespace migraphx {
Paul's avatar
Paul committed
16
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
17

Paul's avatar
Paul committed
18
namespace match {
Paul's avatar
Paul committed
19

Paul's avatar
Paul committed
20
21
struct matcher_context
{
Paul's avatar
Paul committed
22
    matcher_context(instruction_ref i) : last(i) {}
Paul's avatar
Paul committed
23
    std::unordered_map<std::string, instruction_ref> instructions;
Paul's avatar
Paul committed
24
25
    instruction_ref not_found() const { return last; }

Paul's avatar
Paul committed
26
    template <class M>
Paul's avatar
Paul committed
27
28
29
30
31
    bool matched(M m, instruction_ref ins)
    {
        return m.match(*this, ins) != this->not_found();
    }

Paul Fultz II's avatar
Paul Fultz II committed
32
33
34
35
36
37
    template <class M>
    auto lazy_match(M m, instruction_ref ins)
    {
        return [=] { return this->matched(m, ins); };
    }

Paul's avatar
Paul committed
38
    private:
Paul's avatar
Paul committed
39
    instruction_ref last;
Paul's avatar
Paul committed
40
41
};

Paul's avatar
Paul committed
42
/// Convert a predicate function into a matcher
Paul's avatar
Paul committed
43
template <class P>
Paul's avatar
Paul committed
44
45
46
47
struct predicate_matcher
{
    P p;

Paul Fultz II's avatar
Paul Fultz II committed
48
    instruction_ref match(const matcher_context& ctx, instruction_ref ins) const
Paul's avatar
Paul committed
49
50
51
52
53
54
55
56
    {
        assert(ins != ctx.not_found());
        if(p(ins))
            return ins;
        return ctx.not_found();
    }
};

Paul's avatar
Paul committed
57
/// Convert a function into a matcher
Paul's avatar
Paul committed
58
template <class F>
Paul's avatar
Paul committed
59
60
61
62
63
64
65
66
67
68
69
struct function_matcher
{
    F f;

    instruction_ref match(matcher_context& ctx, instruction_ref ins) const
    {
        assert(ins != ctx.not_found());
        return f(ctx, ins);
    }
};

Paul's avatar
Paul committed
70
/// Convert a function into a matcher
Paul's avatar
Paul committed
71
template <class F>
Paul's avatar
Paul committed
72
73
74
75
76
function_matcher<F> make_function_matcher(F f)
{
    return {f};
}

Paul's avatar
Paul committed
77
/// Converts a matcher to bind the instruction to name
Paul's avatar
Paul committed
78
template <class M>
Paul's avatar
Paul committed
79
80
auto bind_match(M m, std::string name)
{
Paul's avatar
Paul committed
81
82
83
84
    return make_function_matcher(
        [ =, name = std::move(name) ](matcher_context & ctx, instruction_ref ins) {
            auto result = m.match(ctx, ins);
            if(result != ctx.not_found())
Paul's avatar
Paul committed
85
                ctx.instructions[name] = ins;
Paul's avatar
Paul committed
86
87
            return result;
        });
Paul's avatar
Paul committed
88
89
}

Paul's avatar
Paul committed
90
/// Convert a matcher to a bindable matcher
Paul's avatar
Paul committed
91
template <class M>
Paul's avatar
Paul committed
92
93
94
95
struct bindable_matcher
{
    M m;

Paul's avatar
Paul committed
96
    auto bind(std::string name) const { return bind_match(m, std::move(name)); }
Paul's avatar
Paul committed
97
98
99
100
101
102
103

    instruction_ref match(matcher_context& ctx, instruction_ref ins) const
    {
        return m.match(ctx, ins);
    }
};

Paul's avatar
Paul committed
104
/// Create a bindable matcher
Paul's avatar
Paul committed
105
template <class M>
Paul's avatar
Paul committed
106
107
108
109
110
bindable_matcher<M> make_bindable_matcher(M m)
{
    return {m};
}

Paul's avatar
Paul committed
111
/// Create a bindable matcher from a function
Paul's avatar
Paul committed
112
template <class F>
Paul's avatar
Paul committed
113
114
115
116
117
bindable_matcher<function_matcher<F>> make_bf_matcher(F f)
{
    return {{f}};
}

Paul's avatar
Paul committed
118
/// Create a bindable matcher from a predicate function
Paul's avatar
Paul committed
119
template <class F>
Paul's avatar
Paul committed
120
121
122
123
124
125
126
127
128
bindable_matcher<predicate_matcher<F>> make_bp_matcher(F f)
{
    return {{f}};
}

using bool_list = std::initializer_list<bool>;

struct id_matcher
{
Paul's avatar
Paul committed
129
    instruction_ref match(matcher_context&, instruction_ref ins) const { return ins; }
Paul's avatar
Paul committed
130
131
};

Paul's avatar
Paul committed
132
/// The basic matcher provides the all_of composability of the matcher
Paul's avatar
Paul committed
133
template <class M>
Paul's avatar
Paul committed
134
135
136
137
struct basic_matcher
{
    M m;

Paul's avatar
Paul committed
138
    template <class... Ts>
Paul's avatar
Paul committed
139
140
141
142
143
144
    auto operator()(Ts... ms) const
    {
        // Copy m because we cant capture `this` by value
        auto mm = m;
        return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
            auto result = mm.match(ctx, ins);
Paul's avatar
Paul committed
145
            if(result != ctx.not_found())
Paul's avatar
Paul committed
146
147
148
149
150
151
152
153
154
155
156
            {
                bool matches = fold([&](auto x, auto y) {
                    return x and y.match(ctx, result) != ctx.not_found();
                })(true, ms...);
                if(matches)
                    return result;
            }
            return ctx.not_found();
        });
    }

Paul's avatar
Paul committed
157
    auto bind(std::string name) const { return bind_match(m, std::move(name)); }
Paul's avatar
Paul committed
158
159
160
161
162
163
164

    instruction_ref match(matcher_context& ctx, instruction_ref ins) const
    {
        return m.match(ctx, ins);
    }
};

Paul's avatar
Paul committed
165
/// Create a basic matcher from a matcher
Paul's avatar
Paul committed
166
template <class M>
Paul's avatar
Paul committed
167
168
169
170
171
basic_matcher<M> make_basic_matcher(M m)
{
    return {m};
}

Paul's avatar
Paul committed
172
/// Create a basic matcher from a function
Paul's avatar
Paul committed
173
template <class F>
Paul's avatar
Paul committed
174
175
176
177
178
basic_matcher<function_matcher<F>> make_basic_fun_matcher(F f)
{
    return {{f}};
}

Paul's avatar
Paul committed
179
/// Create a basic matcher from a predicate function
Paul's avatar
Paul committed
180
template <class P>
Paul's avatar
Paul committed
181
182
183
184
185
basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p)
{
    return {{p}};
}

/// This macro takes care of the boilerplate for defining a matcher
///
/// MIGRAPHX_BASIC_MATCHER(name, params...) does three things:
///  - declares a struct `name##_m` with a `match(params...)` member;
///  - defines a constexpr basic_matcher `name` that wraps it;
///  - begins the out-of-line definition of `name##_m::match`.
/// The function body must follow the macro invocation.
#define MIGRAPHX_BASIC_MATCHER(name, ...)                      \
    struct name##_m                                            \
    {                                                          \
        instruction_ref match(__VA_ARGS__) const;              \
    };                                                         \
    constexpr migraphx::match::basic_matcher<name##_m> name{}; \
    inline instruction_ref name##_m::match(__VA_ARGS__) const

/// This macro takes care of the boilerplate for defining a predicate matcher
///
/// MIGRAPHX_PRED_MATCHER(name, params...) does three things:
///  - declares a functor `name##_m` with `bool operator()(params...)`;
///  - defines a constexpr basic_matcher `name` over
///    predicate_matcher<name##_m>;
///  - begins the out-of-line definition of `name##_m::operator()`.
/// The function body must follow the macro invocation.
#define MIGRAPHX_PRED_MATCHER(name, ...)              \
    struct name##_m                                   \
    {                                                 \
        bool operator()(__VA_ARGS__) const;           \
    };                                                \
    constexpr migraphx::match::basic_matcher<         \
        migraphx::match::predicate_matcher<name##_m>> \
        name{};                                       \
    inline bool name##_m::operator()(__VA_ARGS__) const

/// Outcome of matching a single instruction (see match_instruction).
struct matcher_result
{
    // Instructions bound by name during the match attempt.
    std::unordered_map<std::string, instruction_ref> instructions;
    // Instruction yielded by the top-level matcher; match_instruction leaves
    // the module's end() here when the match fails.
    instruction_ref result;
};

Paul's avatar
Paul committed
211
/// Match a single instruction
Paul's avatar
Paul committed
212
template <class M>
213
matcher_result match_instruction(module& p, instruction_ref ins, M&& m)
Paul's avatar
Paul committed
214
215
216
217
{
    assert(ins != p.end());
    matcher_result result;
    matcher_context ctx{p.end()};
Paul's avatar
Paul committed
218
    result.result       = m.match(ctx, ins);
Paul's avatar
Paul committed
219
    result.instructions = ctx.instructions;
Paul's avatar
Paul committed
220
    return result;
Paul's avatar
Paul committed
221
222
}

223
224
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES)

Paul's avatar
Paul committed
225
226
/// Find matches for an instruction in the program
template <class... Ms>
227
void find_matches(module& p, instruction_ref ins, Ms&&... ms)
Paul's avatar
Paul committed
228
{
229
230
231
232
233
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5
    const
#endif
        bool trace = enabled(MIGRAPHX_TRACE_MATCHES{});
    bool match     = false;
Paul's avatar
Paul committed
234
235
236
237
238
239
240
    each_args(
        [&](auto&& m) {
            if(match)
                return;
            auto r = match_instruction(p, ins, m.matcher());
            if(r.result == p.end())
                return;
241
242
243
244
245
            if(trace)
            {
                std::cout << "Matched by " << get_type_name(m) << std::endl;
                p.debug_print(ins);
            }
Paul's avatar
Paul committed
246
247
248
249
250
251
            m.apply(p, r);
            match = true;
        },
        ms...);
}

Paul's avatar
Paul committed
252
/// Find matches in a program
Paul's avatar
Paul committed
253
template <class... Ms>
254
void find_matches(module& p, Ms&&... ms)
Paul's avatar
Paul committed
255
{
Paul's avatar
Paul committed
256
    for(auto ins : iterator_for(p))
Paul's avatar
Paul committed
257
    {
Paul's avatar
Paul committed
258
        find_matches(p, ins, ms...);
Paul's avatar
Paul committed
259
260
261
    }
}

262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
template <class M, class F>
struct find_generic_match
{
    M m;
    F f;
    M matcher() const { return m; }

    void apply(module& mod, const matcher_result& mr) const { f(mod, mr); }
};

/// Create a find_generic_match finder from matcher `m` and callback `f`.
template <class M, class F>
find_generic_match<M, F> make_match_finder(M m, F f)
{
    return find_generic_match<M, F>{m, f};
}

Paul's avatar
Paul committed
278
template <class M>
Paul's avatar
Paul committed
279
280
281
struct find_skip
{
    M m;
Paul's avatar
Paul committed
282
    M matcher() const { return m; }
Paul's avatar
Paul committed
283

284
    void apply(module&, const matcher_result&) const {}
Paul's avatar
Paul committed
285
286
};

Paul's avatar
Paul committed
287
template <class M>
Paul's avatar
Paul committed
288
289
290
291
292
find_skip<M> make_find_skip(M m)
{
    return {m};
}

/// Lazy logical AND of two nullary callables.
///
/// Returns a thunk that, when invoked, evaluates `f() and g()`; `g` is only
/// called if `f()` is true.
struct lazy_and
{
    template <class F, class G>
    auto operator()(F f, G g) const
    {
        return [=] {
            if(not f())
                return false;
            return static_cast<bool>(g());
        };
    }
};

/// Lazy logical OR of two nullary callables.
///
/// Returns a thunk that, when invoked, evaluates `f() or g()`; `g` is only
/// called if `f()` is false.
struct lazy_or
{
    template <class F, class G>
    auto operator()(F f, G g) const
    {
        return [=] {
            if(f())
                return true;
            return static_cast<bool>(g());
        };
    }
};

Paul's avatar
Paul committed
311
template <class Op, bool Start, bool Matches>
Paul's avatar
Paul committed
312
struct match_fold_f
Paul's avatar
Paul committed
313
{
Paul's avatar
Paul committed
314
    template <class... Ms>
Paul's avatar
Paul committed
315
    static bool fold_matchers(matcher_context& ctx, instruction_ref ins, Ms... ms)
Paul's avatar
Paul committed
316
317
    {
        Op op;
Paul's avatar
Paul committed
318
        auto matched = [&](auto m) { return [=, &ctx] { return ctx.matched(m, ins); }; };
Paul Fultz II's avatar
Paul Fultz II committed
319
        return fold(op)(always(Start), matched(ms)...)();
Paul's avatar
Paul committed
320
321
    }

Paul's avatar
Paul committed
322
323
324
    template <class Pack>
    static bool fold_matchers_pack(matcher_context& ctx, instruction_ref ins, Pack p)
    {
Paul's avatar
Paul committed
325
        return p([&](auto... ms) { return match_fold_f::fold_matchers(ctx, ins, ms...); });
Paul's avatar
Paul committed
326
327
    }

Paul's avatar
Paul committed
328
329
330
331
    template <class... Ts>
    auto operator()(Ts... ms) const
    {
        return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
Paul's avatar
Paul committed
332
            bool matches = match_fold_f::fold_matchers(ctx, ins, ms...);
Paul's avatar
Paul committed
333
334
335
336
337
            if(matches == Matches)
                return ins;
            return ctx.not_found();
        });
    }
Paul's avatar
Paul committed
338

Paul's avatar
Paul committed
339
    template <class Selector>
Paul's avatar
Paul committed
340
341
    auto operator[](Selector select) const
    {
Paul's avatar
Paul committed
342
        return [=](auto... ms) {
Paul's avatar
Paul committed
343
            // Workaround ICE on gcc by packing matchers into an object
Paul's avatar
Paul committed
344
            auto mpack = pack(ms...);
Paul's avatar
Paul committed
345
346
347
348
            return make_bf_matcher([=](matcher_context& ctx, instruction_ref start) {
                Op op;
                bool matches = Start;
                select(start, [&](auto ins) {
Paul's avatar
Paul committed
349
                    auto fm = [&] { return match_fold_f::fold_matchers_pack(ctx, ins, mpack); };
Paul Fultz II's avatar
Paul Fultz II committed
350
                    matches = op(always(matches), fm)();
Paul's avatar
Paul committed
351
352
353
354
355
356
357
358
359
                });
                if(matches == Matches)
                    return start;
                return ctx.not_found();
            });
        };
    }
};

Paul's avatar
Paul committed
360
361
362
const constexpr auto all_of  = match_fold_f<lazy_and, true, true>{};
const constexpr auto any_of  = match_fold_f<lazy_or, false, true>{};
const constexpr auto none_of = match_fold_f<lazy_or, false, false>{};
Paul's avatar
Paul committed
363

/// Build a finder that matches when any of `ms` matches and whose apply()
/// does nothing. Such an instruction is left unchanged, and later finders
/// in the same find_matches call are not tried on it.
template <class... Ms>
auto skip_matches(Ms... ms)
{
    auto m = any_of(ms...);
    return make_find_skip(m);
}

/// Selector that calls `f` on each input of `ins`. Intended for
/// match_fold_f::operator[], e.g. `any_of[inputs()](m)`.
inline auto inputs()
{
    return [](auto ins, auto f) {
        for(const auto& input : ins->inputs())
            f(input);
    };
}

/// Selector that calls `f` on each output of `ins`. Intended for
/// match_fold_f::operator[], e.g. `all_of[outputs()](m)`.
inline auto outputs()
{
    return [](auto ins, auto f) {
        for(const auto& output : ins->outputs())
            f(output);
    };
}

/// Always matches.
MIGRAPHX_PRED_MATCHER(any, instruction_ref) { return true; }
/// Never matches.
MIGRAPHX_PRED_MATCHER(none, instruction_ref) { return false; }

/// Matches instructions whose shape reports standard().
MIGRAPHX_PRED_MATCHER(standard_shape, instruction_ref ins)
{
    const auto& s = ins->get_shape();
    return s.standard();
}

/// Matches instructions whose shape does not report standard().
MIGRAPHX_PRED_MATCHER(not_standard_shape, instruction_ref ins)
{
    const auto& s = ins->get_shape();
    return not s.standard();
}

/// Matches instructions whose shape reports broadcasted().
MIGRAPHX_PRED_MATCHER(broadcast_shape, instruction_ref ins)
{
    const auto& s = ins->get_shape();
    return s.broadcasted();
}

/// Matches instructions whose shape reports transposed().
MIGRAPHX_PRED_MATCHER(transpose_shape, instruction_ref ins)
{
    const auto& s = ins->get_shape();
    return s.transposed();
}

Paul's avatar
Paul committed
403
MIGRAPHX_PRED_MATCHER(same_input_shapes, instruction_ref ins)
Paul's avatar
Paul committed
404
{
Paul's avatar
Paul committed
405
    if(ins->inputs().empty())
Paul's avatar
Paul committed
406
407
        return false;
    auto s = ins->inputs().front()->get_shape();
Paul's avatar
Paul committed
408
409
    return std::all_of(
        ins->inputs().begin(), ins->inputs().end(), [&](auto x) { return x->get_shape() == s; });
Paul's avatar
Paul committed
410
411
}

Paul's avatar
Paul committed
412
MIGRAPHX_BASIC_MATCHER(output, const matcher_context& ctx, instruction_ref ins)
Paul's avatar
Add cbr  
Paul committed
413
414
415
416
417
418
{
    if(ins->outputs().size() == 1)
        return ins->outputs().front();
    return ctx.not_found();
}

Paul's avatar
Paul committed
419
MIGRAPHX_BASIC_MATCHER(used_once, const matcher_context& ctx, instruction_ref ins)
Paul's avatar
Paul committed
420
421
422
423
424
425
426
427
{
    if(ins->outputs().size() == 1)
        return ins;
    if(ins->outputs().empty() and std::next(ins) == ctx.not_found())
        return ins;
    return ctx.not_found();
}

Paul's avatar
Paul committed
428
429
inline auto used_once_recursive(std::size_t depth)
{
Paul's avatar
Paul committed
430
    return make_basic_fun_matcher([=](const matcher_context& ctx, instruction_ref start) {
Paul's avatar
Paul committed
431
432
433
434
435
436
437
438
439
440
441
442
443
        // Used once
        if(start->outputs().size() == 1)
            return start;
        // Unused
        if(start->outputs().empty())
        {
            if(std::next(start) == ctx.not_found())
                return start;
            else
                return ctx.not_found();
        }
        // Check for dead instructions
        auto is_dead = fix<bool>([&](auto self, auto ins, auto n) {
Paul's avatar
Paul committed
444
            if(n == 0)
Paul's avatar
Paul committed
445
446
447
448
449
450
451
452
453
                return false;
            if(ins->get_shape().elements() == 0)
                return false;
            if(ins->outputs().empty() and std::next(ins) != ctx.not_found())
                return true;
            return std::all_of(ins->outputs().begin(), ins->outputs().end(), [&](auto i) {
                return self(i, n - 1);
            });
        });
Paul's avatar
Paul committed
454
        auto dead    = std::count_if(start->outputs().begin(), start->outputs().end(), [&](auto i) {
Paul's avatar
Paul committed
455
456
            return is_dead(i, depth);
        });
Paul's avatar
Paul committed
457
        if(dead + 1 == start->outputs().size())
Paul's avatar
Paul committed
458
459
460
461
462
            return start;
        return ctx.not_found();
    });
}

Paul's avatar
Paul committed
463
MIGRAPHX_PRED_MATCHER(is_constant, instruction_ref ins) { return ins->can_eval(); }
Paul's avatar
Paul committed
464

Paul's avatar
Paul committed
465
MIGRAPHX_BASIC_MATCHER(is_unused, const matcher_context& ctx, instruction_ref ins)
Paul's avatar
Paul committed
466
{
Paul's avatar
Paul committed
467
    if(ins->outputs().empty() and ins != std::prev(ctx.not_found()))
Paul's avatar
Paul committed
468
469
470
471
        return ins;
    return ctx.not_found();
}

472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
template <class... Ms>
auto skip(Ms... ms)
{
    auto m = any_of(ms...);
    return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref start) {
        return fix<instruction_ref>([&](auto self, auto ins) {
            if(ins->inputs().size() == 1 and ctx.matched(m, ins))
            {
                auto next = ins->inputs().front();
                return self(next);
            }
            return ins;
        })(start);
    });
}

Paul's avatar
Paul committed
488
template <class... Ms>
Paul's avatar
Paul committed
489
490
491
492
493
494
495
496
auto skip_output(Ms... ms)
{
    auto m = any_of(ms...);
    return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref start) {
        return fix<instruction_ref>([&](auto self, auto ins) {
            if(ins->outputs().size() == 1)
            {
                auto next = ins->outputs().front();
Paul's avatar
Paul committed
497
                if(ctx.matched(m, next))
Paul's avatar
Paul committed
498
499
                {
                    auto skipped_next = self(next);
Paul's avatar
Paul committed
500
                    if(skipped_next != ctx.not_found())
Paul's avatar
Paul committed
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
                        return skipped_next;
                }
                return next;
            }
            return ctx.not_found();
        })(start);
    });
}

/// Matches instructions whose name() equals `s`.
inline auto name(std::string s)
{
    return make_basic_pred_matcher(
        [expected = std::move(s)](instruction_ref ins) { return ins->name() == expected; });
}

Shucai Xiao's avatar
Shucai Xiao committed
516
517
518
519
520
521
inline auto name_contains(const std::string& name)
{
    return make_basic_pred_matcher(
        [=](instruction_ref ins) { return contains(ins->get_operator().name(), name); });
}

Paul's avatar
Paul committed
522
inline auto name(std::unordered_set<std::string> names)
Paul's avatar
Paul committed
523
{
Paul's avatar
Paul committed
524
525
526
    return make_basic_pred_matcher([ =, names = std::move(names) ](instruction_ref ins) {
        return names.count(ins->name()) > 0;
    });
Paul's avatar
Paul committed
527
528
}

Paul's avatar
Paul committed
529
template <class... Ts>
Paul's avatar
Paul committed
530
inline auto name(std::string s, Ts... xs) // NOLINT
Paul's avatar
Paul committed
531
{
Paul's avatar
Paul committed
532
    return name(std::unordered_set<std::string>{std::move(s), std::move(xs)...});
Paul's avatar
Paul committed
533
534
}

535
536
inline auto nargs(std::size_t n)
{
Paul's avatar
Paul committed
537
    return make_basic_pred_matcher([=](instruction_ref ins) { return ins->inputs().size() == n; });
538
539
}

Paul's avatar
Paul committed
540
541
inline auto arg(std::size_t i)
{
Paul's avatar
Paul committed
542
    return make_basic_fun_matcher([=](const matcher_context& ctx, instruction_ref ins) {
Paul's avatar
Paul committed
543
544
545
546
547
548
549
        if(i < ins->inputs().size())
            return ins->inputs()[i];
        return ctx.not_found();
    });
}

// Workaround for bugs in clang: a plain tag type that carries the argument
// indices handed to args_impl as a non-type template parameter pack.
template <std::size_t... Is>
struct args_impl_ints
{
};

Paul's avatar
Paul committed
555
template <std::size_t... Ns, class... Ms>
Paul's avatar
Paul committed
556
557
auto args_impl(args_impl_ints<Ns...>, Ms... ms)
{
Paul's avatar
Paul committed
558
    return match::all_of(nargs(sizeof...(Ns)), arg(Ns)(ms)...);
Paul's avatar
Paul committed
559
560
}

Paul's avatar
Paul committed
561
template <class... Ms>
Paul's avatar
Paul committed
562
563
564
565
566
auto args(Ms... ms)
{
    return sequence_c<sizeof...(Ms)>([=](auto... is) {
        // It needs to be written as `decltype(is)::value` for gcc 5
        return args_impl(args_impl_ints<decltype(is)::value...>{}, ms...);
Paul's avatar
Paul committed
567
    });
Paul's avatar
Paul committed
568
569
}

Paul's avatar
Paul committed
570
inline auto either_arg(std::size_t i, std::size_t j)
Paul's avatar
Paul committed
571
572
{
    return [=](auto m1, auto m2) {
Paul's avatar
Paul committed
573
574
        return match::any_of(match::all_of(arg(i)(m1), arg(j)(m2)),
                             match::all_of(arg(j)(m1), arg(i)(m2)));
Paul's avatar
Paul committed
575
576
577
    };
}

kahmed10's avatar
kahmed10 committed
578
579
580
581
582
inline auto any_arg(std::size_t i, std::size_t j)
{
    return [=](auto m) { return match::any_of(arg(i)(m), arg(j)(m)); };
}

583
584
585
586
587
template <std::size_t N, class M>
std::size_t tree_leafs_impl(matcher_context& ctx,
                            std::array<instruction_ref, N>& leafs,
                            M m,
                            instruction_ref ins)
Paul Fultz II's avatar
Paul Fultz II committed
588
589
590
591
592
{
    std::size_t idx = 0;
    fix([&](auto self, auto i) {
        if(idx == leafs.size())
            return;
593
        if(ctx.matched(m, i) and i->inputs().size() >= 2)
Paul Fultz II's avatar
Paul Fultz II committed
594
595
596
597
598
599
600
601
602
603
604
        {
            self(i->inputs()[0]);
            self(i->inputs()[1]);
            return;
        }
        leafs[idx] = i;
        idx++;
    })(ins);
    return idx;
}

605
606
template <class M, class... Ms>
auto tree(M main_op, Ms... ms)
Paul Fultz II's avatar
Paul Fultz II committed
607
608
609
610
{
    return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) {
        // Flatten leaf nodes
        std::array<instruction_ref, sizeof...(Ms)> leafs;
611
        std::size_t idx = tree_leafs_impl(ctx, leafs, main_op, ins);
Paul Fultz II's avatar
Paul Fultz II committed
612
613
614
615
616
617
618
619
620
621
622
623
        if(idx != leafs.size())
            return ctx.not_found();
        // Use explicit captures to workaround ICE on gcc
        bool found = sequence_c<sizeof...(Ms)>([&ms..., &ctx, &leafs](auto... is) {
            return fold(lazy_and{})(ctx.lazy_match(ms, leafs[is])...)();
        });
        if(not found)
            return ctx.not_found();
        return ins;
    });
}

624
625
template <class M, class... Ms>
auto unordered_tree(M main_op, Ms... ms)
Paul Fultz II's avatar
Paul Fultz II committed
626
627
628
629
{
    return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) {
        // Flatten leaf nodes
        std::array<instruction_ref, sizeof...(Ms)> leafs;
630
        std::size_t idx = tree_leafs_impl(ctx, leafs, main_op, ins);
Paul Fultz II's avatar
Paul Fultz II committed
631
632
633
634
635
636
637
638
639
640
641
642
643
644
        if(idx != leafs.size())
            return ctx.not_found();
        // Use explicit captures to workaround ICE on gcc
        bool found = sequence_c<sizeof...(Ms)>([ms..., &ctx, &leafs](auto... is) {
            return by(fold(lazy_and{}), [is..., &ctx, &leafs](auto m) {
                return fold(lazy_or{})(ctx.lazy_match(m, leafs[is])...);
            })(ms...)();
        });
        if(not found)
            return ctx.not_found();
        return ins;
    });
}

Paul's avatar
Paul committed
645
template <class M>
Paul's avatar
Paul committed
646
647
648
649
auto same_shape(M m)
{
    return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) {
        auto i = m.match(ctx, ins);
Paul's avatar
Paul committed
650
        if(i != ctx.not_found() and i->get_shape() == ins->get_shape())
Paul's avatar
Paul committed
651
652
653
654
655
            return ins;
        return ctx.not_found();
    });
}

Paul's avatar
Paul committed
656
template <class... Ms>
Paul's avatar
Paul committed
657
658
659
660
661
auto same_shape(Ms... ms)
{
    return all_of(same_shape(ms)...);
}

662
663
664
665
666
667
template <class... Ms>
auto skip_broadcasts(Ms... ms)
{
    return skip(name("broadcast", "multibroadcast", "contiguous"))(ms...);
}

kahmed10's avatar
kahmed10 committed
668
669
670
template <class T>
inline auto has_value(T x, float tolerance = 1e-6)
{
671
    return skip_broadcasts(make_basic_pred_matcher([=](instruction_ref ins) {
kahmed10's avatar
kahmed10 committed
672
        if(ins->name() != "@literal")
kahmed10's avatar
kahmed10 committed
673
674
675
676
677
            return false;
        auto l = ins->get_literal();
        if(l.empty())
            return false;
        bool b = false;
kahmed10's avatar
kahmed10 committed
678
        l.visit([&](auto v) {
Paul Fultz II's avatar
Paul Fultz II committed
679
680
            if(std::all_of(
                   v.begin(), v.end(), [&](auto val) { return std::fabs(val - x) < tolerance; }))
kahmed10's avatar
kahmed10 committed
681
682
                b = true;
        });
kahmed10's avatar
kahmed10 committed
683
        return b;
684
    }));
kahmed10's avatar
kahmed10 committed
685
686
}

687
688
689
690
691
692
inline auto has_attribute(const std::string& name)
{
    return make_basic_pred_matcher(
        [=](instruction_ref ins) { return ins->get_operator().attributes().contains(name); });
}

Paul's avatar
Paul committed
693
} // namespace match
Paul's avatar
Paul committed
694
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
695
} // namespace migraphx
Paul's avatar
Paul committed
696
697

#endif