// remap.cpp
#include <migraphx/remap.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/add.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace {
struct find_dot_add
{
    auto matcher() const
    {
        return match::name("add")(match::any_of(
            match::args(match::name("dot")(match::nargs(2)).bind("dot"), match::any().bind("a")),
            match::args(match::used_once().bind("a"),
                        match::name("dot")(match::nargs(2)).bind("dot"))));
    }

25
    void apply(module& p, match::matcher_result r) const
26
27
28
29
30
31
32
33
34
35
36
37
38
    {
        auto ins     = r.result;
        auto dot_ins = r.instructions["dot"];
        auto a_ins   = r.instructions["a"];

        auto dot = any_cast<op::dot>(dot_ins->get_operator());

        dot.beta = 1;
        p.replace_instruction(ins, dot, dot_ins->inputs()[0], dot_ins->inputs()[1], a_ins);
    }
};
} // namespace

39
void remap::apply(module& p) const { match::find_matches(p, find_dot_add{}); }
40
41
42

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx