matcher.hpp 23.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
Paul's avatar
Paul committed
24
25
#ifndef MIGRAPHX_GUARD_RTGLIB_MATCHER_HPP
#define MIGRAPHX_GUARD_RTGLIB_MATCHER_HPP
Paul's avatar
Paul committed
26

Paul's avatar
Paul committed
27
28
29
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
30
#include <migraphx/module.hpp>
31
#include <migraphx/optional.hpp>
Paul's avatar
Paul committed
32
#include <migraphx/iterator_for.hpp>
33
#include <migraphx/type_name.hpp>
Paul's avatar
Paul committed
34
#include <migraphx/config.hpp>
Paul's avatar
Paul committed
35
#include <unordered_map>
Paul's avatar
Paul committed
36
#include <unordered_set>
Paul's avatar
Paul committed
37

Paul's avatar
Paul committed
38
namespace migraphx {
Paul's avatar
Paul committed
39
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
40

Paul's avatar
Paul committed
41
namespace match {
Paul's avatar
Paul committed
42

Paul's avatar
Paul committed
43
44
struct matcher_context
{
45
    matcher_context(module& m) : mod(&m) {}
Paul's avatar
Paul committed
46
    std::unordered_map<std::string, instruction_ref> instructions;
Paul's avatar
Paul committed
47

Paul's avatar
Paul committed
48
    template <class M>
Paul's avatar
Paul committed
49
50
    bool matched(M m, instruction_ref ins)
    {
51
        return has_value(m.match(*this, ins));
Paul's avatar
Paul committed
52
53
    }

Paul Fultz II's avatar
Paul Fultz II committed
54
    template <class M>
55
56
57
58
59
60
61
62
63
    bool matched(M m, optional<instruction_ref> ins)
    {
        if(ins)
            return has_value(m.match(*this, *ins));
        return false;
    }

    template <class M, class I>
    auto lazy_match(M m, I ins)
Paul Fultz II's avatar
Paul Fultz II committed
64
65
66
67
    {
        return [=] { return this->matched(m, ins); };
    }

68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
    bool has_instruction(instruction_ref ins) const
    {
        if(mod == nullptr)
            return true;
        return mod->has_instruction(ins);
    }
    bool has_instruction(optional<instruction_ref> ins) const
    {
        if(ins)
            return this->has_instruction(*ins);
        return false;
    }

    bool is_last(instruction_ref ins) const
    {
        assert(mod->begin() != mod->end());
        assert(this->has_instruction(ins));
        return ins == std::prev(mod->end());
    }

Paul's avatar
Paul committed
88
    private:
89
    module* mod = nullptr;
Paul's avatar
Paul committed
90
91
};

Paul's avatar
Paul committed
92
/// Convert a predicate function into a matcher
Paul's avatar
Paul committed
93
template <class P>
Paul's avatar
Paul committed
94
95
96
97
struct predicate_matcher
{
    P p;

98
    optional<instruction_ref> match(const matcher_context&, instruction_ref ins) const
Paul's avatar
Paul committed
99
100
    {
        if(p(ins))
101
102
            return optional<instruction_ref>{ins};
        return nullopt;
Paul's avatar
Paul committed
103
104
105
    }
};

Paul's avatar
Paul committed
106
/// Convert a function into a matcher
Paul's avatar
Paul committed
107
template <class F>
Paul's avatar
Paul committed
108
109
110
111
struct function_matcher
{
    F f;

112
    auto match(matcher_context& ctx, instruction_ref ins) const { return f(ctx, ins); }
Paul's avatar
Paul committed
113
114
};

Paul's avatar
Paul committed
115
/// Convert a function into a matcher
Paul's avatar
Paul committed
116
template <class F>
Paul's avatar
Paul committed
117
118
119
120
121
function_matcher<F> make_function_matcher(F f)
{
    return {f};
}

Paul's avatar
Paul committed
122
/// Converts a matcher to bind the instruction to name
Paul's avatar
Paul committed
123
template <class M>
Paul's avatar
Paul committed
124
125
auto bind_match(M m, std::string name)
{
Paul's avatar
Paul committed
126
    return make_function_matcher(
bpickrel's avatar
bpickrel committed
127
128
129
130
131
132
133
134
135
136
137
        [=, name = std::move(name)](matcher_context& ctx,
                                    instruction_ref ins) -> optional<instruction_ref> {
            auto result = m.match(ctx, ins);
            if(result)
            {
                if(not ctx.has_instruction(ins))
                    return nullopt;
                ctx.instructions[name] = ins;
            }
            return result;
        });
Paul's avatar
Paul committed
138
139
}

Paul's avatar
Paul committed
140
/// Convert a matcher to a bindable matcher
Paul's avatar
Paul committed
141
template <class M>
Paul's avatar
Paul committed
142
143
144
145
struct bindable_matcher
{
    M m;

Paul's avatar
Paul committed
146
    auto bind(std::string name) const { return bind_match(m, std::move(name)); }
Paul's avatar
Paul committed
147

148
    auto match(matcher_context& ctx, instruction_ref ins) const { return m.match(ctx, ins); }
Paul's avatar
Paul committed
149
150
};

Paul's avatar
Paul committed
151
/// Create a bindable matcher
Paul's avatar
Paul committed
152
template <class M>
Paul's avatar
Paul committed
153
154
155
156
157
bindable_matcher<M> make_bindable_matcher(M m)
{
    return {m};
}

Paul's avatar
Paul committed
158
/// Create a bindable matcher from a function
Paul's avatar
Paul committed
159
template <class F>
Paul's avatar
Paul committed
160
161
162
163
164
bindable_matcher<function_matcher<F>> make_bf_matcher(F f)
{
    return {{f}};
}

Paul's avatar
Paul committed
165
/// Create a bindable matcher from a predicate function
Paul's avatar
Paul committed
166
template <class F>
Paul's avatar
Paul committed
167
168
169
170
171
172
173
174
175
bindable_matcher<predicate_matcher<F>> make_bp_matcher(F f)
{
    return {{f}};
}

using bool_list = std::initializer_list<bool>;

struct id_matcher
{
176
177
178
179
    auto match(matcher_context&, instruction_ref ins) const
    {
        return optional<instruction_ref>{ins};
    }
Paul's avatar
Paul committed
180
181
};

182
183
184
185
186
187
188
189
190
191
192
193
194
// Forward declare class and constructors.
// basic_matcher::operator() calls make_basic_fun_matcher, which is only
// defined after basic_matcher itself, so these declarations must come first.
template <class M>
struct basic_matcher;

// Wrap any matcher in a basic_matcher
template <class M>
basic_matcher<M> make_basic_matcher(M m);

// Wrap a callable f(ctx, ins) in a basic_matcher
template <class F>
basic_matcher<function_matcher<F>> make_basic_fun_matcher(F f);

// Wrap a predicate bool(instruction_ref) in a basic_matcher
template <class P>
basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p);

Paul's avatar
Paul committed
195
/// The basic matcher provides the all_of composability of the matcher
Paul's avatar
Paul committed
196
template <class M>
Paul's avatar
Paul committed
197
198
199
200
struct basic_matcher
{
    M m;

Paul's avatar
Paul committed
201
    template <class... Ts>
Paul's avatar
Paul committed
202
203
204
205
    auto operator()(Ts... ms) const
    {
        // Copy m because we cant capture `this` by value
        auto mm = m;
206
207
        return make_basic_fun_matcher([=](matcher_context& ctx,
                                          instruction_ref ins) -> optional<instruction_ref> {
Paul's avatar
Paul committed
208
            auto result = mm.match(ctx, ins);
209
            if(result)
Paul's avatar
Paul committed
210
            {
211
212
                bool matches =
                    fold([&](auto x, auto y) { return x and ctx.matched(y, result); })(true, ms...);
Paul's avatar
Paul committed
213
214
215
                if(matches)
                    return result;
            }
216
            return nullopt;
Paul's avatar
Paul committed
217
218
219
        });
    }

Paul's avatar
Paul committed
220
    auto bind(std::string name) const { return bind_match(m, std::move(name)); }
Paul's avatar
Paul committed
221

222
    auto match(matcher_context& ctx, instruction_ref ins) const { return m.match(ctx, ins); }
Paul's avatar
Paul committed
223
224
};

Paul's avatar
Paul committed
225
/// Create a basic matcher from a matcher
Paul's avatar
Paul committed
226
template <class M>
Paul's avatar
Paul committed
227
228
229
230
231
basic_matcher<M> make_basic_matcher(M m)
{
    return {m};
}

Paul's avatar
Paul committed
232
/// Create a basic matcher from a function
Paul's avatar
Paul committed
233
template <class F>
Paul's avatar
Paul committed
234
235
236
237
238
basic_matcher<function_matcher<F>> make_basic_fun_matcher(F f)
{
    return {{f}};
}

Paul's avatar
Paul committed
239
/// Create a basic matcher from a predicate function
Paul's avatar
Paul committed
240
template <class P>
Paul's avatar
Paul committed
241
242
243
244
245
basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p)
{
    return {{p}};
}

246
247
/// Create a typed-erased matcher
using any_matcher_base = basic_matcher<
248
    function_matcher<std::function<optional<instruction_ref>(matcher_context&, instruction_ref)>>>;
249
250
251
252
253
254
255
256
struct any_matcher : any_matcher_base
{
    template <class M>
    any_matcher(M mm) : any_matcher_base({[=](auto& ctx, auto ins) { return mm.match(ctx, ins); }})
    {
    }
};

Paul's avatar
Paul committed
257
/// This macro takes care of the boilerplate for defining a matcher.
/// It declares a `name##_m` struct whose `match` member takes the macro's extra
/// arguments, defines a `basic_matcher` constant called `name` wrapping it, and
/// opens the out-of-line definition of `name##_m::match`. The macro invocation
/// must therefore be followed directly by the function body.
#define MIGRAPHX_BASIC_MATCHER(name, ...)                                     \
    struct name##_m                                                           \
    {                                                                         \
        optional<instruction_ref> match(__VA_ARGS__) const;                   \
    };                                                                        \
    constexpr const auto name = migraphx::match::basic_matcher<name##_m>{{}}; \
    inline optional<instruction_ref> name##_m::match(__VA_ARGS__) const
Paul's avatar
Paul committed
265

Paul's avatar
Paul committed
266
/// This macro takes care of the boilerplate for defining a predicate matcher.
/// It declares a `name##_m` functor whose `operator()` takes the macro's extra
/// arguments and returns bool, defines a `basic_matcher` constant called `name`
/// over `predicate_matcher<name##_m>`, and opens the out-of-line definition of
/// `name##_m::operator()`. The invocation must be followed by the function body.
#define MIGRAPHX_PRED_MATCHER(name, ...)                                                  \
    struct name##_m                                                                       \
    {                                                                                     \
        bool operator()(__VA_ARGS__) const;                                               \
    };                                                                                    \
    constexpr const auto name =                                                           \
        migraphx::match::basic_matcher<migraphx::match::predicate_matcher<name##_m>>{{}}; \
    inline bool name##_m::operator()(__VA_ARGS__) const
Paul's avatar
Paul committed
275
276
277

struct matcher_result
{
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
    struct instruction_container
    {
        instruction_container() = default;
        instruction_container(std::unordered_map<std::string, instruction_ref> x)
            : ins_map(std::move(x))
        {
        }

        instruction_ref operator[](const std::string& name) const
        {
            auto it = ins_map.find(name);
            if(it == ins_map.end())
                MIGRAPHX_THROW("Accessing name that wasn't bound in matcher: " + name);
            return it->second;
        }

        auto find(const std::string& name) const { return ins_map.find(name); }

        auto begin() const { return ins_map.cbegin(); }

        auto end() const { return ins_map.cend(); }

        bool has_instructions_in(const module& mod) const
        {
            return std::all_of(ins_map.begin(), ins_map.end(), [&](auto&& p) {
                return mod.has_instruction(p.second);
            });
        }

        private:
        std::unordered_map<std::string, instruction_ref> ins_map;
    };
    instruction_container instructions;
Paul's avatar
Paul committed
311
312
313
    instruction_ref result;
};

Paul's avatar
Paul committed
314
/// Match a single instruction
Paul's avatar
Paul committed
315
template <class M>
316
matcher_result match_instruction(module& mod, instruction_ref ins, M&& m)
Paul's avatar
Paul committed
317
{
318
319
320
    assert(ins != mod.end());
    assert(mod.has_instruction(ins));
    matcher_context ctx{mod};
Paul's avatar
Paul committed
321
    matcher_result result;
322
323
324
325
    if(m.match(ctx, ins))
    {
        result.result       = ins;
        result.instructions = ctx.instructions;
326
        assert(result.instructions.has_instructions_in(mod));
327
328
329
330
331
    }
    else
    {
        result.result = mod.end();
    }
Paul's avatar
Paul committed
332
    return result;
Paul's avatar
Paul committed
333
334
}

turneram's avatar
turneram committed
335
336
337
338
339
340
341
342
343
344
345
346
347
348
/// Find the first instruction in `modl` that matcher `m` matches.
/// Returns that instruction's match result. When nothing matches (including
/// when the module is empty) the returned `result` is `modl.end()` and no
/// names are bound.
template <class M>
match::matcher_result find_match(module& modl, M&& m)
{
    for(auto ins : iterator_for(modl))
    {
        auto result = match::match_instruction(modl, ins, m);
        if(result.result != modl.end())
            return result;
    }
    // Signal "no match" explicitly: a default-constructed matcher_result holds
    // a singular iterator, and comparing that with end() is undefined.
    match::matcher_result none;
    none.result = modl.end();
    return none;
}

349
350
// Environment switch read by find_matches: when enabled, every successful match
// prints the finder's type name and the matched instruction.
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES)

351
/// Find matches for an instruction in the module
Paul's avatar
Paul committed
352
template <class... Ms>
353
void find_matches(module& mod, instruction_ref ins, Ms&&... ms)
Paul's avatar
Paul committed
354
{
355
356
357
358
359
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5
    const
#endif
        bool trace = enabled(MIGRAPHX_TRACE_MATCHES{});
    bool match     = false;
Paul's avatar
Paul committed
360
361
362
363
    each_args(
        [&](auto&& m) {
            if(match)
                return;
364
365
            auto r = match_instruction(mod, ins, m.matcher());
            if(r.result == mod.end())
Paul's avatar
Paul committed
366
                return;
367
368
369
            if(trace)
            {
                std::cout << "Matched by " << get_type_name(m) << std::endl;
370
                mod.debug_print(ins);
371
            }
372
            m.apply(mod, r);
Paul's avatar
Paul committed
373
374
375
376
377
            match = true;
        },
        ms...);
}

378
/// Find matches in a module
Paul's avatar
Paul committed
379
template <class... Ms>
380
void find_matches(module& mod, Ms&&... ms)
Paul's avatar
Paul committed
381
{
382
    for(auto ins : iterator_for(mod))
Paul's avatar
Paul committed
383
    {
384
        find_matches(mod, ins, ms...);
Paul's avatar
Paul committed
385
386
387
    }
}

388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
/// Finder built from a matcher `m` and a callback `f(mod, result)` that is
/// invoked for each match
template <class M, class F>
struct find_generic_match
{
    M m;
    F f;

    M matcher() const { return m; }

    void apply(module& mod, const matcher_result& mr) const
    {
        f(mod, mr);
    }
};

/// Create a finder that calls `f(mod, result)` whenever `m` matches
template <class M, class F>
find_generic_match<M, F> make_match_finder(M m, F f)
{
    return find_generic_match<M, F>{m, f};
}

Paul's avatar
Paul committed
404
template <class M>
Paul's avatar
Paul committed
405
406
407
struct find_skip
{
    M m;
Paul's avatar
Paul committed
408
    M matcher() const { return m; }
Paul's avatar
Paul committed
409

410
    void apply(module&, const matcher_result&) const {}
Paul's avatar
Paul committed
411
412
};

Paul's avatar
Paul committed
413
template <class M>
Paul's avatar
Paul committed
414
415
416
417
418
find_skip<M> make_find_skip(M m)
{
    return {m};
}

419
420
/// Combine two nullary boolean callables into one computing `f() and g()`
/// lazily: nothing runs until the result is invoked, and `g` is skipped when
/// `f` returns false.
struct lazy_and
{
    template <class F, class G>
    auto operator()(F f, G g) const
    {
        return [=]() -> bool {
            if(not f())
                return false;
            return static_cast<bool>(g());
        };
    }
};

/// Combine two nullary boolean callables into one computing `f() or g()`
/// lazily: nothing runs until the result is invoked, and `g` is skipped when
/// `f` returns true.
struct lazy_or
{
    template <class F, class G>
    auto operator()(F f, G g) const
    {
        return [=]() -> bool {
            if(f())
                return true;
            return static_cast<bool>(g());
        };
    }
};

Paul's avatar
Paul committed
437
/// Folds matchers with a lazy boolean combinator (lazy_and / lazy_or).
/// Start is the initial value of the fold. Matches is the folded value that
/// counts as success, so Matches == false turns the fold into a negation.
template <class Op, bool Start, bool Matches>
struct match_fold_f
{
    /// Combine `ctx.matched(m, ins)` for every m in `ms` with Op, starting
    /// from Start. Op is given nullary callables, so a lazy Op can skip the
    /// remaining matchers once the result is decided.
    template <class... Ms>
    static bool fold_matchers(matcher_context& ctx, instruction_ref ins, Ms... ms)
    {
        Op op;
        auto matched = [&](auto m) { return [=, &ctx] { return ctx.matched(m, ins); }; };
        return fold(op)(always(Start), matched(ms)...)();
    }

    /// Same as fold_matchers, with the matchers stored in a `pack` object
    template <class Pack>
    static bool fold_matchers_pack(matcher_context& ctx, instruction_ref ins, Pack p)
    {
        return p([&](auto... ms) { return match_fold_f::fold_matchers(ctx, ins, ms...); });
    }

    /// Build a matcher that folds `ms` over the instruction being matched and
    /// yields that instruction when the folded value equals Matches
    template <class... Ts>
    auto operator()(Ts... ms) const
    {
        return make_bf_matcher(
            [=](matcher_context& ctx, instruction_ref ins) -> optional<instruction_ref> {
                bool matches = match_fold_f::fold_matchers(ctx, ins, ms...);
                if(matches == Matches)
                    return {ins};
                return nullopt;
            });
    }

    /// Selector form, e.g. `all_of[inputs()](ms...)`: `select(start, f)` calls
    /// `f` on each related instruction. Each of them is tested with the same
    /// fold over `ms`, and those per-instruction results are folded again with
    /// Op. On success the matcher yields `start`.
    template <class Selector>
    auto operator[](Selector select) const
    {
        return [=](auto... ms) {
            // Workaround ICE on gcc by packing matchers into an object
            auto mpack = pack(ms...);
            return make_bf_matcher(
                [=](matcher_context& ctx, instruction_ref start) -> optional<instruction_ref> {
                    Op op;
                    bool matches = Start;
                    select(start, [&](auto ins) {
                        // Lazily evaluated so Op can skip it once decided
                        auto fm = [&] { return match_fold_f::fold_matchers_pack(ctx, ins, mpack); };
                        matches = op(always(matches), fm)();
                    });
                    if(matches == Matches)
                        return {start};
                    return nullopt;
                });
        };
    }
};

Paul's avatar
Paul committed
488
489
490
/// Succeeds when every given matcher matches (lazy `and`, starting at true)
constexpr auto all_of = match_fold_f<lazy_and, true, true>{};
/// Succeeds when at least one given matcher matches (lazy `or`, starting at false)
constexpr auto any_of = match_fold_f<lazy_or, false, true>{};
/// Succeeds when no given matcher matches (the `or` fold must stay false)
constexpr auto none_of = match_fold_f<lazy_or, false, false>{};
Paul's avatar
Paul committed
491

Paul's avatar
Paul committed
492
/// Create a finder that claims, without modifying, any instruction matched by
/// one of `ms`
template <class... Ms>
auto skip_matches(Ms... ms)
{
    auto m = any_of(ms...);
    return make_find_skip(m);
}

Paul's avatar
Paul committed
498
/// Selector for the `all_of[...]` / `any_of[...]` / `none_of[...]` forms:
/// calls `f` on every input of `ins`, in order
inline auto inputs()
{
    return [](auto ins, auto f) {
        for(auto&& input : ins->inputs())
            f(input);
    };
}

Paul's avatar
Paul committed
506
/// Selector for the `all_of[...]` / `any_of[...]` / `none_of[...]` forms:
/// calls `f` on every output (user) of `ins`, in order
inline auto outputs()
{
    return [](auto ins, auto f) {
        for(auto&& output : ins->outputs())
            f(output);
    };
}

Paul's avatar
Paul committed
514
515
516
/// Matches every instruction
MIGRAPHX_PRED_MATCHER(any, instruction_ref) { return true; }
/// Matches no instruction
MIGRAPHX_PRED_MATCHER(none, instruction_ref) { return false; }
/// Matches instructions whose shape reports standard()
MIGRAPHX_PRED_MATCHER(standard_shape, instruction_ref ins)
{
    const auto& s = ins->get_shape();
    return s.standard();
}
/// Matches instructions whose shape does not report standard()
MIGRAPHX_PRED_MATCHER(not_standard_shape, instruction_ref ins)
{
    const auto& s = ins->get_shape();
    return not s.standard();
}
/// Matches instructions whose shape reports broadcasted()
MIGRAPHX_PRED_MATCHER(broadcast_shape, instruction_ref ins)
{
    const auto& s = ins->get_shape();
    return s.broadcasted();
}
/// Matches instructions whose shape reports transposed()
MIGRAPHX_PRED_MATCHER(transpose_shape, instruction_ref ins)
{
    const auto& s = ins->get_shape();
    return s.transposed();
}

Paul's avatar
Paul committed
531
/// Matches instructions with at least one input whose inputs all have the same
/// shape as the first input
MIGRAPHX_PRED_MATCHER(same_input_shapes, instruction_ref ins)
{
    const auto& xs = ins->inputs();
    if(xs.empty())
        return false;
    auto first_shape = xs.front()->get_shape();
    return std::all_of(
        xs.begin(), xs.end(), [&](auto x) { return x->get_shape() == first_shape; });
}

540
/// Succeeds when `ins` has exactly one output and yields that output, so
/// follow-up matchers apply to the consumer
MIGRAPHX_BASIC_MATCHER(output, const matcher_context&, instruction_ref ins)
{
    const auto& outs = ins->outputs();
    if(outs.size() != 1)
        return nullopt;
    return {outs.front()};
}

Paul's avatar
Paul committed
547
/// Matches an instruction with exactly one output, or one with no outputs that
/// is the last instruction of the module
MIGRAPHX_BASIC_MATCHER(used_once, const matcher_context& ctx, instruction_ref ins)
{
    const auto n_outputs = ins->outputs().size();
    // is_last is only evaluated for instructions without outputs
    const bool single_use  = n_outputs == 1;
    const bool last_unused = n_outputs == 0 and ctx.is_last(ins);
    if(single_use or last_unused)
        return {ins};
    return nullopt;
}

Paul's avatar
Paul committed
556
/// Matches instructions for which can_eval() is true
MIGRAPHX_PRED_MATCHER(is_constant, instruction_ref ins)
{
    return ins->can_eval();
}
Paul's avatar
Paul committed
557

Paul's avatar
Paul committed
558
/// Matches instructions that have no outputs and are not the module's last
/// instruction
MIGRAPHX_BASIC_MATCHER(is_unused, const matcher_context& ctx, instruction_ref ins)
{
    // is_last is only evaluated for instructions without outputs
    if(not ins->outputs().empty() or ctx.is_last(ins))
        return nullopt;
    return {ins};
}

565
566
567
568
569
template <class... Ms>
auto skip(Ms... ms)
{
    auto m = any_of(ms...);
    return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref start) {
570
571
572
573
574
575
576
577
578
        return fix<optional<instruction_ref>>(
            [&](auto self, auto ins) -> optional<instruction_ref> {
                if(ins->inputs().size() == 1 and ctx.matched(m, ins))
                {
                    auto next = ins->inputs().front();
                    return self(next);
                }
                return ins;
            })(start);
579
580
581
    });
}

Paul's avatar
Paul committed
582
template <class... Ms>
Paul's avatar
Paul committed
583
584
585
586
auto skip_output(Ms... ms)
{
    auto m = any_of(ms...);
    return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref start) {
587
588
589
        return fix<optional<instruction_ref>>(
            [&](auto self, auto ins) -> optional<instruction_ref> {
                if(ins->outputs().size() == 1)
Paul's avatar
Paul committed
590
                {
591
592
593
594
595
596
597
598
                    auto next = ins->outputs().front();
                    if(ctx.matched(m, next))
                    {
                        auto skipped_next = self(next);
                        if(skipped_next)
                            return skipped_next;
                    }
                    return next;
Paul's avatar
Paul committed
599
                }
600
601
                return nullopt;
            })(start);
Paul's avatar
Paul committed
602
603
604
    });
}

605
606
607
608
609
610
611
612
613
614
615
616
/// Yields the instruction previously bound under `s`, ignoring the instruction
/// being matched. Fails when nothing is bound under that name.
inline auto var(std::string s)
{
    return make_basic_fun_matcher(
        [=, s = std::move(s)](const matcher_context& ctx,
                              instruction_ref) -> optional<instruction_ref> {
            auto found = ctx.instructions.find(s);
            if(found != ctx.instructions.end())
                return found->second;
            return nullopt;
        });
}

Paul's avatar
Paul committed
617
618
619
inline auto name(std::string s)
{
    return make_basic_pred_matcher(
bpickrel's avatar
bpickrel committed
620
        [=, s = std::move(s)](instruction_ref ins) { return ins->name() == s; });
Paul's avatar
Paul committed
621
622
}

Shucai Xiao's avatar
Shucai Xiao committed
623
624
625
626
627
628
/// Matches instructions for which `contains(<instruction name>, name)` holds,
/// i.e. whose name contains `name`.
inline auto name_contains(const std::string& name)
{
    // Use ins->name(), as the other `name` matchers do. Going through
    // get_operator() copies the whole type-erased operation on every test just
    // to read its name.
    return make_basic_pred_matcher(
        [=](instruction_ref ins) { return contains(ins->name(), name); });
}

Paul's avatar
Paul committed
629
inline auto name(std::unordered_set<std::string> names)
Paul's avatar
Paul committed
630
{
bpickrel's avatar
bpickrel committed
631
    return make_basic_pred_matcher([=, names = std::move(names)](instruction_ref ins) {
Paul's avatar
Paul committed
632
633
        return names.count(ins->name()) > 0;
    });
Paul's avatar
Paul committed
634
635
}

Paul's avatar
Paul committed
636
template <class... Ts>
Paul's avatar
Paul committed
637
inline auto name(std::string s, Ts... xs) // NOLINT
Paul's avatar
Paul committed
638
{
Paul's avatar
Paul committed
639
    return name(std::unordered_set<std::string>{std::move(s), std::move(xs)...});
Paul's avatar
Paul committed
640
641
}

642
643
inline auto nargs(std::size_t n)
{
Paul's avatar
Paul committed
644
    return make_basic_pred_matcher([=](instruction_ref ins) { return ins->inputs().size() == n; });
645
646
}

Paul's avatar
Paul committed
647
648
inline auto arg(std::size_t i)
{
649
650
651
652
653
654
    return make_basic_fun_matcher(
        [=](const matcher_context&, instruction_ref ins) -> optional<instruction_ref> {
            if(i < ins->inputs().size())
                return ins->inputs()[i];
            return nullopt;
        });
Paul's avatar
Paul committed
655
656
657
}

// Workaround for bugs in clang
Paul's avatar
Paul committed
658
659
660
661
template <std::size_t...>
struct args_impl_ints
{
};
Paul's avatar
Paul committed
662

Paul's avatar
Paul committed
663
template <std::size_t... Ns, class... Ms>
Paul's avatar
Paul committed
664
665
auto args_impl(args_impl_ints<Ns...>, Ms... ms)
{
Paul's avatar
Paul committed
666
    return match::all_of(nargs(sizeof...(Ns)), arg(Ns)(ms)...);
Paul's avatar
Paul committed
667
668
}

Paul's avatar
Paul committed
669
/// Matches instructions that have exactly sizeof...(Ms) inputs, where input k
/// matches ms[k]
template <class... Ms>
auto args(Ms... ms)
{
    // sequence_c supplies one index constant per matcher
    return sequence_c<sizeof...(Ms)>([=](auto... is) {
        // It needs to be written as `decltype(is)::value` for gcc 5
        return args_impl(args_impl_ints<decltype(is)::value...>{}, ms...);
    });
}

Paul's avatar
Paul committed
678
inline auto either_arg(std::size_t i, std::size_t j)
Paul's avatar
Paul committed
679
680
{
    return [=](auto m1, auto m2) {
Paul's avatar
Paul committed
681
682
        return match::any_of(match::all_of(arg(i)(m1), arg(j)(m2)),
                             match::all_of(arg(j)(m1), arg(i)(m2)));
Paul's avatar
Paul committed
683
684
685
    };
}

kahmed10's avatar
kahmed10 committed
686
687
688
689
690
/// Returns a factory taking a matcher `m`. The resulting matcher succeeds when
/// input i or input j matches `m`.
inline auto any_arg(std::size_t i, std::size_t j)
{
    return [=](auto m) {
        auto via_i = arg(i)(m);
        auto via_j = arg(j)(m);
        return match::any_of(via_i, via_j);
    };
}

691
692
693
694
695
template <std::size_t N, class M>
std::size_t tree_leafs_impl(matcher_context& ctx,
                            std::array<instruction_ref, N>& leafs,
                            M m,
                            instruction_ref ins)
Paul Fultz II's avatar
Paul Fultz II committed
696
697
698
699
700
{
    std::size_t idx = 0;
    fix([&](auto self, auto i) {
        if(idx == leafs.size())
            return;
701
        if(ctx.matched(m, i) and i->inputs().size() >= 2)
Paul Fultz II's avatar
Paul Fultz II committed
702
703
704
705
706
707
708
709
710
711
712
        {
            self(i->inputs()[0]);
            self(i->inputs()[1]);
            return;
        }
        leafs[idx] = i;
        idx++;
    })(ins);
    return idx;
}

713
714
/// Match a tree of instructions. Internal nodes are instructions matched by
/// `main_op` with at least two inputs (see tree_leafs_impl). The leaf count
/// reported by tree_leafs_impl must equal sizeof...(Ms), and leaf k (left to
/// right) must match ms[k]. On success yields the root instruction.
template <class M, class... Ms>
auto tree(M main_op, Ms... ms)
{
    return make_basic_fun_matcher(
        [=](matcher_context& ctx, instruction_ref ins) -> optional<instruction_ref> {
            // Flatten leaf nodes
            std::array<instruction_ref, sizeof...(Ms)> leafs;
            std::size_t idx = tree_leafs_impl(ctx, leafs, main_op, ins);
            // The leaf count must equal the number of leaf matchers
            if(idx != leafs.size())
                return nullopt;
            // Use explicit captures to workaround ICE on gcc
            // Capture by value to workaround compile error on gcc 9
            bool found = sequence_c<sizeof...(Ms)>([ms..., &ctx, &leafs](auto... is) {
                // Leaf k is matched against ms[k]; the lazy_and fold stops at
                // the first failure
                return fold(lazy_and{})(ctx.lazy_match(ms, leafs[is])...)();
            });
            if(not found)
                return nullopt;
            return ins;
        });
}

734
735
/// Like `tree`, but leaf order does not matter: every matcher in `ms` must match
/// at least one of the collected leaves.
/// NOTE(review): a single leaf may satisfy several matchers; this is not a
/// one-to-one assignment of matchers to leaves.
template <class M, class... Ms>
auto unordered_tree(M main_op, Ms... ms)
{
    return make_basic_fun_matcher(
        [=](matcher_context& ctx, instruction_ref ins) -> optional<instruction_ref> {
            // Flatten leaf nodes
            std::array<instruction_ref, sizeof...(Ms)> leafs;
            std::size_t idx = tree_leafs_impl(ctx, leafs, main_op, ins);
            // The leaf count must equal the number of leaf matchers
            if(idx != leafs.size())
                return nullopt;
            // Use explicit captures to workaround ICE on gcc
            bool found = sequence_c<sizeof...(Ms)>([ms..., &ctx, &leafs](auto... is) {
                // For each matcher: lazy `or` over all leaves; then lazy `and`
                // across the matchers
                return by(fold(lazy_and{}), [is..., &ctx, &leafs](auto m) {
                    return fold(lazy_or{})(ctx.lazy_match(m, leafs[is])...);
                })(ms...)();
            });
            if(not found)
                return nullopt;
            return ins;
        });
}

Paul's avatar
Paul committed
756
template <class M>
Paul's avatar
Paul committed
757
758
auto same_shape(M m)
{
759
760
761
762
763
764
765
    return make_basic_fun_matcher(
        [=](matcher_context& ctx, instruction_ref ins) -> optional<instruction_ref> {
            auto i = m.match(ctx, ins);
            if(i and (*i)->get_shape() == ins->get_shape())
                return ins;
            return nullopt;
        });
Paul's avatar
Paul committed
766
767
}

Paul's avatar
Paul committed
768
/// Succeeds when every matcher in `matchers` passes the single-matcher
/// same_shape check
template <class... Ms>
auto same_shape(Ms... matchers)
{
    return all_of(same_shape(matchers)...);
}

774
775
776
777
778
779
/// Skip over broadcast, multibroadcast and contiguous instructions (see `skip`)
/// and apply `ms` to the instruction reached
template <class... Ms>
auto skip_broadcasts(Ms... ms)
{
    auto skipper = skip(name("broadcast", "multibroadcast", "contiguous"));
    return skipper(ms...);
}

780
781
782
783
784
785
/// Skip over broadcast, multibroadcast, contiguous and convert instructions
/// (see `skip`) and apply `ms` to the instruction reached
template <class... Ms>
auto skip_broadcasts_converts(Ms... ms)
{
    auto skipper = skip(name("broadcast", "multibroadcast", "contiguous", "convert"));
    return skipper(ms...);
}

kahmed10's avatar
kahmed10 committed
786
787
788
template <class T>
inline auto has_value(T x, float tolerance = 1e-6)
{
789
    return skip_broadcasts_converts(make_basic_pred_matcher([=](instruction_ref ins) {
kahmed10's avatar
kahmed10 committed
790
        if(ins->name() != "@literal")
kahmed10's avatar
kahmed10 committed
791
792
793
794
795
            return false;
        auto l = ins->get_literal();
        if(l.empty())
            return false;
        bool b = false;
kahmed10's avatar
kahmed10 committed
796
        l.visit([&](auto v) {
Paul Fultz II's avatar
Paul Fultz II committed
797
798
            if(std::all_of(
                   v.begin(), v.end(), [&](auto val) { return std::fabs(val - x) < tolerance; }))
kahmed10's avatar
kahmed10 committed
799
800
                b = true;
        });
kahmed10's avatar
kahmed10 committed
801
        return b;
802
    }));
kahmed10's avatar
kahmed10 committed
803
804
}

805
806
807
808
809
810
/// Matches instructions whose operator attributes contain `name`
inline auto has_attribute(const std::string& name)
{
    return make_basic_pred_matcher([=](instruction_ref ins) {
        auto attrs = ins->get_operator().attributes();
        return attrs.contains(name);
    });
}

811
812
813
814
815
816
817
/// Matches operators carrying the "pointwise" attribute that take one or two
/// inputs and that also satisfy every matcher in `ms`
template <class... Ms>
auto pointwise(Ms... ms)
{
    auto unary_or_binary = match::any_of(match::nargs(1), match::nargs(2));
    return match::has_attribute("pointwise")(unary_or_binary, ms...);
}

Paul's avatar
Paul committed
818
} // namespace match
Paul's avatar
Paul committed
819
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
820
} // namespace migraphx
Paul's avatar
Paul committed
821
822

#endif