matcher.hpp 17.8 KB
Newer Older
Paul's avatar
Paul committed
1
2
#ifndef MIGRAPHX_GUARD_RTGLIB_MATCHER_HPP
#define MIGRAPHX_GUARD_RTGLIB_MATCHER_HPP
Paul's avatar
Paul committed
3

Paul's avatar
Paul committed
4
5
6
7
8
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp>
9
#include <migraphx/type_name.hpp>
Paul's avatar
Paul committed
10
#include <migraphx/config.hpp>
Paul's avatar
Paul committed
11
#include <unordered_map>
Paul's avatar
Paul committed
12
#include <unordered_set>
Paul's avatar
Paul committed
13

Paul's avatar
Paul committed
14
namespace migraphx {
Paul's avatar
Paul committed
15
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
16

Paul's avatar
Paul committed
17
namespace match {
Paul's avatar
Paul committed
18

Paul's avatar
Paul committed
19
20
struct matcher_context
{
Paul's avatar
Paul committed
21
    matcher_context(instruction_ref i) : last(i) {}
Paul's avatar
Paul committed
22
    std::unordered_map<std::string, instruction_ref> instructions;
Paul's avatar
Paul committed
23
24
    instruction_ref not_found() const { return last; }

Paul's avatar
Paul committed
25
    template <class M>
Paul's avatar
Paul committed
26
27
28
29
30
    bool matched(M m, instruction_ref ins)
    {
        return m.match(*this, ins) != this->not_found();
    }

Paul Fultz II's avatar
Paul Fultz II committed
31
32
33
34
35
36
    template <class M>
    auto lazy_match(M m, instruction_ref ins)
    {
        return [=] { return this->matched(m, ins); };
    }

Paul's avatar
Paul committed
37
    private:
Paul's avatar
Paul committed
38
    instruction_ref last;
Paul's avatar
Paul committed
39
40
};

Paul's avatar
Paul committed
41
/// Convert a predicate function into a matcher
Paul's avatar
Paul committed
42
template <class P>
Paul's avatar
Paul committed
43
44
45
46
struct predicate_matcher
{
    P p;

Paul Fultz II's avatar
Paul Fultz II committed
47
    instruction_ref match(const matcher_context& ctx, instruction_ref ins) const
Paul's avatar
Paul committed
48
49
50
51
52
53
54
55
    {
        assert(ins != ctx.not_found());
        if(p(ins))
            return ins;
        return ctx.not_found();
    }
};

Paul's avatar
Paul committed
56
/// Convert a function into a matcher
Paul's avatar
Paul committed
57
template <class F>
Paul's avatar
Paul committed
58
59
60
61
62
63
64
65
66
67
68
struct function_matcher
{
    F f;

    instruction_ref match(matcher_context& ctx, instruction_ref ins) const
    {
        assert(ins != ctx.not_found());
        return f(ctx, ins);
    }
};

Paul's avatar
Paul committed
69
/// Convert a function into a matcher
Paul's avatar
Paul committed
70
template <class F>
Paul's avatar
Paul committed
71
72
73
74
75
function_matcher<F> make_function_matcher(F f)
{
    return {f};
}

Paul's avatar
Paul committed
76
/// Converts a matcher to bind the instruction to name
Paul's avatar
Paul committed
77
template <class M>
Paul's avatar
Paul committed
78
79
auto bind_match(M m, std::string name)
{
Paul's avatar
Paul committed
80
81
82
83
    return make_function_matcher(
        [ =, name = std::move(name) ](matcher_context & ctx, instruction_ref ins) {
            auto result = m.match(ctx, ins);
            if(result != ctx.not_found())
Paul's avatar
Paul committed
84
                ctx.instructions[name] = ins;
Paul's avatar
Paul committed
85
86
            return result;
        });
Paul's avatar
Paul committed
87
88
}

Paul's avatar
Paul committed
89
/// Convert a matcher to a bindable matcher
Paul's avatar
Paul committed
90
template <class M>
Paul's avatar
Paul committed
91
92
93
94
struct bindable_matcher
{
    M m;

Paul's avatar
Paul committed
95
    auto bind(std::string name) const { return bind_match(m, std::move(name)); }
Paul's avatar
Paul committed
96
97
98
99
100
101
102

    instruction_ref match(matcher_context& ctx, instruction_ref ins) const
    {
        return m.match(ctx, ins);
    }
};

Paul's avatar
Paul committed
103
/// Create a bindable matcher
Paul's avatar
Paul committed
104
template <class M>
Paul's avatar
Paul committed
105
106
107
108
109
bindable_matcher<M> make_bindable_matcher(M m)
{
    return {m};
}

Paul's avatar
Paul committed
110
/// Create a bindable matcher from a function
Paul's avatar
Paul committed
111
template <class F>
Paul's avatar
Paul committed
112
113
114
115
116
bindable_matcher<function_matcher<F>> make_bf_matcher(F f)
{
    return {{f}};
}

Paul's avatar
Paul committed
117
/// Create a bindable matcher from a predicate function
Paul's avatar
Paul committed
118
template <class F>
Paul's avatar
Paul committed
119
120
121
122
123
124
125
126
127
bindable_matcher<predicate_matcher<F>> make_bp_matcher(F f)
{
    return {{f}};
}

using bool_list = std::initializer_list<bool>;

struct id_matcher
{
Paul's avatar
Paul committed
128
    instruction_ref match(matcher_context&, instruction_ref ins) const { return ins; }
Paul's avatar
Paul committed
129
130
};

Paul's avatar
Paul committed
131
/// The basic matcher provides the all_of composability of the matcher
Paul's avatar
Paul committed
132
template <class M>
Paul's avatar
Paul committed
133
134
135
136
struct basic_matcher
{
    M m;

Paul's avatar
Paul committed
137
    template <class... Ts>
Paul's avatar
Paul committed
138
139
140
141
142
143
    auto operator()(Ts... ms) const
    {
        // Copy m because we cant capture `this` by value
        auto mm = m;
        return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
            auto result = mm.match(ctx, ins);
Paul's avatar
Paul committed
144
            if(result != ctx.not_found())
Paul's avatar
Paul committed
145
146
147
148
149
150
151
152
153
154
155
            {
                bool matches = fold([&](auto x, auto y) {
                    return x and y.match(ctx, result) != ctx.not_found();
                })(true, ms...);
                if(matches)
                    return result;
            }
            return ctx.not_found();
        });
    }

Paul's avatar
Paul committed
156
    auto bind(std::string name) const { return bind_match(m, std::move(name)); }
Paul's avatar
Paul committed
157
158
159
160
161
162
163

    instruction_ref match(matcher_context& ctx, instruction_ref ins) const
    {
        return m.match(ctx, ins);
    }
};

Paul's avatar
Paul committed
164
/// Create a basic matcher from a matcher
Paul's avatar
Paul committed
165
template <class M>
Paul's avatar
Paul committed
166
167
168
169
170
basic_matcher<M> make_basic_matcher(M m)
{
    return {m};
}

Paul's avatar
Paul committed
171
/// Create a basic matcher from a function
Paul's avatar
Paul committed
172
template <class F>
Paul's avatar
Paul committed
173
174
175
176
177
basic_matcher<function_matcher<F>> make_basic_fun_matcher(F f)
{
    return {{f}};
}

Paul's avatar
Paul committed
178
/// Create a basic matcher from a predicate function
Paul's avatar
Paul committed
179
template <class P>
Paul's avatar
Paul committed
180
181
182
183
184
basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p)
{
    return {{p}};
}

Paul's avatar
Paul committed
185
/// Declares a basic matcher named `name` whose match body follows the macro.
///
/// `MIGRAPHX_BASIC_MATCHER(foo, const matcher_context& ctx, instruction_ref ins) { ... }`
/// expands to:
/// - a struct `foo_m` declaring `instruction_ref match(...) const`;
/// - a constant `foo` of type `basic_matcher<foo_m>`;
/// - the out-of-line definition header of `foo_m::match`, whose body is the
///   braces written after the macro invocation.
#define MIGRAPHX_BASIC_MATCHER(name, ...)                                     \
    struct name##_m                                                           \
    {                                                                         \
        instruction_ref match(__VA_ARGS__) const;                             \
    };                                                                        \
    constexpr const auto name = migraphx::match::basic_matcher<name##_m>{{}}; \
    inline instruction_ref name##_m::match(__VA_ARGS__) const

Paul's avatar
Paul committed
194
/// Declares a predicate matcher named `name` whose predicate body follows the macro.
///
/// `MIGRAPHX_PRED_MATCHER(foo, instruction_ref ins) { return ...; }` expands to:
/// - a struct `foo_m` declaring `bool operator()(...) const`;
/// - a constant `foo` of type `basic_matcher<predicate_matcher<foo_m>>`;
/// - the out-of-line definition header of `foo_m::operator()`, whose body is
///   the braces written after the macro invocation.
#define MIGRAPHX_PRED_MATCHER(name, ...)                                                  \
    struct name##_m                                                                       \
    {                                                                                     \
        bool operator()(__VA_ARGS__) const;                                               \
    };                                                                                    \
    constexpr const auto name =                                                           \
        migraphx::match::basic_matcher<migraphx::match::predicate_matcher<name##_m>>{{}}; \
    inline bool name##_m::operator()(__VA_ARGS__) const
Paul's avatar
Paul committed
203
204
205
206
207
208
209

/// Outcome of running a matcher on a single instruction (see match_instruction).
struct matcher_result
{
    // Instructions recorded by `bind(...)` during the match, keyed by bind name.
    std::unordered_map<std::string, instruction_ref> instructions;
    // The matcher's result; equals the module's end() when nothing matched.
    instruction_ref result;
};

Paul's avatar
Paul committed
210
/// Match a single instruction
Paul's avatar
Paul committed
211
template <class M>
212
matcher_result match_instruction(module& p, instruction_ref ins, M&& m)
Paul's avatar
Paul committed
213
214
215
216
{
    assert(ins != p.end());
    matcher_result result;
    matcher_context ctx{p.end()};
Paul's avatar
Paul committed
217
    result.result       = m.match(ctx, ins);
Paul's avatar
Paul committed
218
    result.instructions = ctx.instructions;
Paul's avatar
Paul committed
219
    return result;
Paul's avatar
Paul committed
220
221
}

222
223
// Environment variable checked by find_matches; when enabled, each applied
// finder's type name and the matched instruction are printed.
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES)

Paul's avatar
Paul committed
224
225
/// Find matches for an instruction in the program
template <class... Ms>
226
void find_matches(module& p, instruction_ref ins, Ms&&... ms)
Paul's avatar
Paul committed
227
{
228
229
230
231
232
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5
    const
#endif
        bool trace = enabled(MIGRAPHX_TRACE_MATCHES{});
    bool match     = false;
Paul's avatar
Paul committed
233
234
235
236
237
238
239
    each_args(
        [&](auto&& m) {
            if(match)
                return;
            auto r = match_instruction(p, ins, m.matcher());
            if(r.result == p.end())
                return;
240
241
242
243
244
            if(trace)
            {
                std::cout << "Matched by " << get_type_name(m) << std::endl;
                p.debug_print(ins);
            }
Paul's avatar
Paul committed
245
246
247
248
249
250
            m.apply(p, r);
            match = true;
        },
        ms...);
}

Paul's avatar
Paul committed
251
/// Find matches in a program
Paul's avatar
Paul committed
252
template <class... Ms>
253
void find_matches(module& p, Ms&&... ms)
Paul's avatar
Paul committed
254
{
Paul's avatar
Paul committed
255
    for(auto ins : iterator_for(p))
Paul's avatar
Paul committed
256
    {
Paul's avatar
Paul committed
257
        find_matches(p, ins, ms...);
Paul's avatar
Paul committed
258
259
260
    }
}

Paul's avatar
Paul committed
261
template <class M>
Paul's avatar
Paul committed
262
263
264
struct find_skip
{
    M m;
Paul's avatar
Paul committed
265
    M matcher() const { return m; }
Paul's avatar
Paul committed
266

267
    void apply(module&, const matcher_result&) const {}
Paul's avatar
Paul committed
268
269
};

Paul's avatar
Paul committed
270
template <class M>
Paul's avatar
Paul committed
271
272
273
274
275
find_skip<M> make_find_skip(M m)
{
    return {m};
}

276
277
/// Combine two nullary boolean thunks into a thunk computing `f() and g()`.
/// `g` is only invoked when `f()` is true.
struct lazy_and
{
    template <class F, class G>
    auto operator()(F f, G g) const
    {
        return [=] {
            if(not f())
                return false;
            return static_cast<bool>(g());
        };
    }
};

/// Combine two nullary boolean thunks into a thunk computing `f() or g()`.
/// `g` is only invoked when `f()` is false.
struct lazy_or
{
    template <class F, class G>
    auto operator()(F f, G g) const
    {
        return [=] {
            if(f())
                return true;
            return static_cast<bool>(g());
        };
    }
};

Paul's avatar
Paul committed
294
template <class Op, bool Start, bool Matches>
Paul's avatar
Paul committed
295
struct match_fold_f
Paul's avatar
Paul committed
296
{
Paul's avatar
Paul committed
297
    template <class... Ms>
Paul's avatar
Paul committed
298
    static bool fold_matchers(matcher_context& ctx, instruction_ref ins, Ms... ms)
Paul's avatar
Paul committed
299
300
    {
        Op op;
Paul's avatar
Paul committed
301
        auto matched = [&](auto m) { return [=, &ctx] { return ctx.matched(m, ins); }; };
Paul Fultz II's avatar
Paul Fultz II committed
302
        return fold(op)(always(Start), matched(ms)...)();
Paul's avatar
Paul committed
303
304
    }

Paul's avatar
Paul committed
305
306
307
    template <class Pack>
    static bool fold_matchers_pack(matcher_context& ctx, instruction_ref ins, Pack p)
    {
Paul's avatar
Paul committed
308
        return p([&](auto... ms) { return match_fold_f::fold_matchers(ctx, ins, ms...); });
Paul's avatar
Paul committed
309
310
    }

Paul's avatar
Paul committed
311
312
313
314
    template <class... Ts>
    auto operator()(Ts... ms) const
    {
        return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
Paul's avatar
Paul committed
315
            bool matches = match_fold_f::fold_matchers(ctx, ins, ms...);
Paul's avatar
Paul committed
316
317
318
319
320
            if(matches == Matches)
                return ins;
            return ctx.not_found();
        });
    }
Paul's avatar
Paul committed
321

Paul's avatar
Paul committed
322
    template <class Selector>
Paul's avatar
Paul committed
323
324
    auto operator[](Selector select) const
    {
Paul's avatar
Paul committed
325
        return [=](auto... ms) {
Paul's avatar
Paul committed
326
            // Workaround ICE on gcc by packing matchers into an object
Paul's avatar
Paul committed
327
            auto mpack = pack(ms...);
Paul's avatar
Paul committed
328
329
330
331
            return make_bf_matcher([=](matcher_context& ctx, instruction_ref start) {
                Op op;
                bool matches = Start;
                select(start, [&](auto ins) {
Paul's avatar
Paul committed
332
                    auto fm = [&] { return match_fold_f::fold_matchers_pack(ctx, ins, mpack); };
Paul Fultz II's avatar
Paul Fultz II committed
333
                    matches = op(always(matches), fm)();
Paul's avatar
Paul committed
334
335
336
337
338
339
340
341
342
                });
                if(matches == Matches)
                    return start;
                return ctx.not_found();
            });
        };
    }
};

Paul's avatar
Paul committed
343
344
345
// all_of: every sub-matcher must match (lazy AND, starts at true).
constexpr auto all_of = match_fold_f<lazy_and, true, true>{};
// any_of: at least one sub-matcher must match (lazy OR, starts at false).
constexpr auto any_of = match_fold_f<lazy_or, false, true>{};
// none_of: succeeds when the lazy OR of the sub-matchers is false.
constexpr auto none_of = match_fold_f<lazy_or, false, false>{};
Paul's avatar
Paul committed
346

Paul's avatar
Paul committed
347
/// Return a finder that claims any instruction matched by one of `ms` and does
/// nothing with it (see find_skip).
template <class... Ms>
auto skip_matches(Ms... ms)
{
    auto combined = any_of(ms...);
    return make_find_skip(combined);
}

Paul's avatar
Paul committed
353
/// Selector for `match_fold_f::operator[]`: invokes `f` on each input of `ins`.
inline auto inputs()
{
    return [](auto ins, auto f) {
        for(auto&& input : ins->inputs())
        {
            f(input);
        }
    };
}

Paul's avatar
Paul committed
361
/// Selector for `match_fold_f::operator[]`: invokes `f` on each output of `ins`.
inline auto outputs()
{
    return [](auto ins, auto f) {
        for(auto&& user : ins->outputs())
        {
            f(user);
        }
    };
}

Paul's avatar
Paul committed
369
370
371
// Matches every instruction.
MIGRAPHX_PRED_MATCHER(any, instruction_ref)
{
    return true;
}
// Matches no instruction.
MIGRAPHX_PRED_MATCHER(none, instruction_ref)
{
    return false;
}
// Matches instructions whose output shape reports `standard()`.
MIGRAPHX_PRED_MATCHER(standard_shape, instruction_ref ins)
{
    const auto& out_shape = ins->get_shape();
    return out_shape.standard();
}
Paul's avatar
Paul committed
372
373
374
375
// Matches instructions whose output shape does not report `standard()`.
MIGRAPHX_PRED_MATCHER(not_standard_shape, instruction_ref ins)
{
    const bool is_standard = ins->get_shape().standard();
    return not is_standard;
}
Paul's avatar
Paul committed
376
// Matches instructions whose output shape reports `broadcasted()`.
MIGRAPHX_PRED_MATCHER(broadcast_shape, instruction_ref ins)
{
    const auto& out_shape = ins->get_shape();
    return out_shape.broadcasted();
}
Paul's avatar
Paul committed
380

Paul's avatar
Paul committed
381
382
383
384
385
// Matches instructions whose output shape reports `transposed()`.
MIGRAPHX_PRED_MATCHER(transpose_shape, instruction_ref ins)
{
    const auto& out_shape = ins->get_shape();
    return out_shape.transposed();
}

Paul's avatar
Paul committed
386
// Matches instructions with at least one input whose inputs all have the same shape.
MIGRAPHX_PRED_MATCHER(same_input_shapes, instruction_ref ins)
{
    const auto& in = ins->inputs();
    if(in.empty())
        return false;
    const auto reference = in.front()->get_shape();
    auto has_reference_shape = [&](auto input) { return input->get_shape() == reference; };
    return std::all_of(in.begin(), in.end(), has_reference_shape);
}

Paul's avatar
Paul committed
395
// Matches when `ins` has exactly one output; the result is that output
// instruction rather than `ins` itself.
MIGRAPHX_BASIC_MATCHER(output, const matcher_context& ctx, instruction_ref ins)
{
    const auto& users = ins->outputs();
    if(users.size() != 1)
        return ctx.not_found();
    return users.front();
}

Paul's avatar
Paul committed
402
// Matches `ins` when it has exactly one output, or when it has no outputs and
// the instruction after it is the not-found sentinel (the last instruction when
// the sentinel is the module's end()).
MIGRAPHX_BASIC_MATCHER(used_once, const matcher_context& ctx, instruction_ref ins)
{
    const auto num_outputs = ins->outputs().size();
    if(num_outputs == 1 or (num_outputs == 0 and std::next(ins) == ctx.not_found()))
        return ins;
    return ctx.not_found();
}

Paul's avatar
Paul committed
411
412
inline auto used_once_recursive(std::size_t depth)
{
Paul's avatar
Paul committed
413
    return make_basic_fun_matcher([=](const matcher_context& ctx, instruction_ref start) {
Paul's avatar
Paul committed
414
415
416
417
418
419
420
421
422
423
424
425
426
        // Used once
        if(start->outputs().size() == 1)
            return start;
        // Unused
        if(start->outputs().empty())
        {
            if(std::next(start) == ctx.not_found())
                return start;
            else
                return ctx.not_found();
        }
        // Check for dead instructions
        auto is_dead = fix<bool>([&](auto self, auto ins, auto n) {
Paul's avatar
Paul committed
427
            if(n == 0)
Paul's avatar
Paul committed
428
429
430
431
432
433
434
435
436
                return false;
            if(ins->get_shape().elements() == 0)
                return false;
            if(ins->outputs().empty() and std::next(ins) != ctx.not_found())
                return true;
            return std::all_of(ins->outputs().begin(), ins->outputs().end(), [&](auto i) {
                return self(i, n - 1);
            });
        });
Paul's avatar
Paul committed
437
        auto dead    = std::count_if(start->outputs().begin(), start->outputs().end(), [&](auto i) {
Paul's avatar
Paul committed
438
439
            return is_dead(i, depth);
        });
Paul's avatar
Paul committed
440
        if(dead + 1 == start->outputs().size())
Paul's avatar
Paul committed
441
442
443
444
445
            return start;
        return ctx.not_found();
    });
}

Paul's avatar
Paul committed
446
// Matches instructions for which `can_eval()` is true.
MIGRAPHX_PRED_MATCHER(is_constant, instruction_ref ins)
{
    const bool evaluable = ins->can_eval();
    return evaluable;
}
Paul's avatar
Paul committed
447

Paul's avatar
Paul committed
448
// Matches instructions that have no outputs and are not the instruction
// immediately before the not-found sentinel (i.e. not the last instruction
// when the sentinel is the module's end()).
MIGRAPHX_BASIC_MATCHER(is_unused, const matcher_context& ctx, instruction_ref ins)
{
    if(not ins->outputs().empty())
        return ctx.not_found();
    if(ins == std::prev(ctx.not_found()))
        return ctx.not_found();
    return ins;
}

Paul's avatar
Paul committed
455
template <class... Ms>
Paul's avatar
Paul committed
456
457
458
459
460
461
462
463
auto skip_output(Ms... ms)
{
    auto m = any_of(ms...);
    return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref start) {
        return fix<instruction_ref>([&](auto self, auto ins) {
            if(ins->outputs().size() == 1)
            {
                auto next = ins->outputs().front();
Paul's avatar
Paul committed
464
                if(ctx.matched(m, next))
Paul's avatar
Paul committed
465
466
                {
                    auto skipped_next = self(next);
Paul's avatar
Paul committed
467
                    if(skipped_next != ctx.not_found())
Paul's avatar
Paul committed
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
                        return skipped_next;
                }
                return next;
            }
            return ctx.not_found();
        })(start);
    });
}

/// Match instructions whose operator name equals `s`.
inline auto name(std::string s)
{
    auto same_name = [expected = std::move(s)](instruction_ref ins) { return ins->name() == expected; };
    return make_basic_pred_matcher(std::move(same_name));
}

inline auto name(std::unordered_set<std::string> names)
Paul's avatar
Paul committed
484
{
Paul's avatar
Paul committed
485
486
487
    return make_basic_pred_matcher([ =, names = std::move(names) ](instruction_ref ins) {
        return names.count(ins->name()) > 0;
    });
Paul's avatar
Paul committed
488
489
}

Paul's avatar
Paul committed
490
template <class... Ts>
Paul's avatar
Paul committed
491
inline auto name(std::string s, Ts... xs) // NOLINT
Paul's avatar
Paul committed
492
{
Paul's avatar
Paul committed
493
    return name(std::unordered_set<std::string>{std::move(s), std::move(xs)...});
Paul's avatar
Paul committed
494
495
}

496
497
inline auto nargs(std::size_t n)
{
Paul's avatar
Paul committed
498
    return make_basic_pred_matcher([=](instruction_ref ins) { return ins->inputs().size() == n; });
499
500
}

Paul's avatar
Paul committed
501
502
inline auto arg(std::size_t i)
{
Paul's avatar
Paul committed
503
    return make_basic_fun_matcher([=](const matcher_context& ctx, instruction_ref ins) {
Paul's avatar
Paul committed
504
505
506
507
508
509
510
        if(i < ins->inputs().size())
            return ins->inputs()[i];
        return ctx.not_found();
    });
}

// Workaround for bugs in clang: carries the input indices as a type so that
// args_impl can deduce them as a pack.
template <std::size_t...>
struct args_impl_ints {};
Paul's avatar
Paul committed
515

Paul's avatar
Paul committed
516
template <std::size_t... Ns, class... Ms>
Paul's avatar
Paul committed
517
518
auto args_impl(args_impl_ints<Ns...>, Ms... ms)
{
Paul's avatar
Paul committed
519
    return match::all_of(nargs(sizeof...(Ns)), arg(Ns)(ms)...);
Paul's avatar
Paul committed
520
521
}

Paul's avatar
Paul committed
522
template <class... Ms>
Paul's avatar
Paul committed
523
524
525
526
527
auto args(Ms... ms)
{
    return sequence_c<sizeof...(Ms)>([=](auto... is) {
        // It needs to be written as `decltype(is)::value` for gcc 5
        return args_impl(args_impl_ints<decltype(is)::value...>{}, ms...);
Paul's avatar
Paul committed
528
    });
Paul's avatar
Paul committed
529
530
}

Paul's avatar
Paul committed
531
inline auto either_arg(std::size_t i, std::size_t j)
Paul's avatar
Paul committed
532
533
{
    return [=](auto m1, auto m2) {
Paul's avatar
Paul committed
534
535
        return match::any_of(match::all_of(arg(i)(m1), arg(j)(m2)),
                             match::all_of(arg(j)(m1), arg(i)(m2)));
Paul's avatar
Paul committed
536
537
538
    };
}

kahmed10's avatar
kahmed10 committed
539
540
541
542
543
/// Return a function of one matcher that matches when input `i` or input `j`
/// matches it.
inline auto any_arg(std::size_t i, std::size_t j)
{
    return [=](auto m) {
        auto via_i = arg(i)(m);
        auto via_j = arg(j)(m);
        return match::any_of(via_i, via_j);
    };
}

Paul Fultz II's avatar
Paul Fultz II committed
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
/// Flatten the tree of `s` instructions rooted at `ins` into `leafs`.
///
/// The walk is depth-first, left to right. Any instruction named `s` with at
/// least two inputs is an interior node; only its first two inputs are
/// followed. Every other instruction reached is a leaf.
///
/// Returns the number of leaves found, capped at N + 1. Only the first N leaves
/// are stored. A return value of exactly N means the tree has exactly N leaves;
/// N + 1 means it has more, so callers checking `== N` reject it instead of
/// silently ignoring the extra leaves.
template <std::size_t N>
std::size_t
tree_leafs_impl(std::array<instruction_ref, N>& leafs, const std::string& s, instruction_ref ins)
{
    std::size_t idx = 0;
    fix([&](auto self, auto i) {
        // Once we know there are more leaves than slots, stop walking
        if(idx > leafs.size())
            return;
        if(i->name() == s and i->inputs().size() >= 2)
        {
            self(i->inputs()[0]);
            self(i->inputs()[1]);
            return;
        }
        // Store leaves that fit; an overflowing leaf is still counted
        if(idx < leafs.size())
            leafs[idx] = i;
        idx++;
    })(ins);
    return idx;
}

/// Match the tree of `s` instructions rooted at `ins` when it has exactly
/// sizeof...(Ms) leaves (as collected by tree_leafs_impl) and the k-th leaf
/// matches the k-th matcher. On success the result is the root `ins`.
template <class... Ms>
auto tree(std::string s, Ms... ms)
{
    return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) {
        // Flatten leaf nodes
        std::array<instruction_ref, sizeof...(Ms)> leafs;
        if(tree_leafs_impl(leafs, s, ins) != leafs.size())
            return ctx.not_found();
        // Use explicit captures to workaround ICE on gcc
        const bool all_leaves_match = sequence_c<sizeof...(Ms)>([&ms..., &ctx, &leafs](auto... is) {
            return fold(lazy_and{})(ctx.lazy_match(ms, leafs[is])...)();
        });
        return all_leaves_match ? ins : ctx.not_found();
    });
}

/// Like `tree`, but leaf order is ignored. Each matcher must match at least one
/// leaf, and a single leaf may satisfy several matchers. On success the result
/// is the root `ins`.
template <class... Ms>
auto unordered_tree(std::string s, Ms... ms)
{
    return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) {
        // Flatten leaf nodes
        std::array<instruction_ref, sizeof...(Ms)> leafs;
        if(tree_leafs_impl(leafs, s, ins) != leafs.size())
            return ctx.not_found();
        // Use explicit captures to workaround ICE on gcc
        const bool every_matcher_hits = sequence_c<sizeof...(Ms)>([ms..., &ctx, &leafs](auto... is) {
            return by(fold(lazy_and{}), [is..., &ctx, &leafs](auto m) {
                return fold(lazy_or{})(ctx.lazy_match(m, leafs[is])...);
            })(ms...)();
        });
        return every_matcher_hits ? ins : ctx.not_found();
    });
}

Paul's avatar
Paul committed
604
template <class M>
Paul's avatar
Paul committed
605
606
607
608
auto same_shape(M m)
{
    return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) {
        auto i = m.match(ctx, ins);
Paul's avatar
Paul committed
609
        if(i != ctx.not_found() and i->get_shape() == ins->get_shape())
Paul's avatar
Paul committed
610
611
612
613
614
            return ins;
        return ctx.not_found();
    });
}

Paul's avatar
Paul committed
615
template <class... Ms>
Paul's avatar
Paul committed
616
617
618
619
620
auto same_shape(Ms... ms)
{
    return all_of(same_shape(ms)...);
}

kahmed10's avatar
kahmed10 committed
621
622
623
624
/// Match `@literal` instructions whose literal is non-empty and every element
/// of which lies within `tolerance` (absolute difference) of `x`.
template <class T>
inline auto has_value(T x, float tolerance = 1e-6)
{
    return make_basic_pred_matcher([=](instruction_ref ins) {
        if(ins->name() != "@literal")
            return false;
        auto lit = ins->get_literal();
        if(lit.empty())
            return false;
        bool all_close = false;
        lit.visit([&](auto values) {
            auto close_to_x = [&](auto val) { return std::fabs(val - x) < tolerance; };
            all_close       = all_close or std::all_of(values.begin(), values.end(), close_to_x);
        });
        return all_close;
    });
}

640
641
642
643
644
645
/// Match instructions whose operator's attributes contain key `name`.
inline auto has_attribute(const std::string& name)
{
    return make_basic_pred_matcher([attr = name](instruction_ref ins) {
        return ins->get_operator().attributes().contains(attr);
    });
}

Paul's avatar
Paul committed
646
} // namespace match
Paul's avatar
Paul committed
647
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
648
} // namespace migraphx
Paul's avatar
Paul committed
649
650

#endif