prefuse_ops.cpp 3.15 KB
Newer Older
1
2
3
#include <migraphx/gpu/prefuse_ops.hpp>
#include <migraphx/match/layernorm.hpp>
#include <migraphx/make_op.hpp>
Paul's avatar
Paul committed
4
#include <migraphx/register_op.hpp>
5
6
7
8

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
Paul's avatar
Paul committed
9
namespace {
Paul's avatar
Paul committed
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
// Operator registered under the name "gpu::prelayernorm". In this file it only
// provides a name and a shape rule; it is what find_layernorm substitutes for a
// matched layernorm pattern.
struct layernorm
{
    std::string name() const { return "gpu::prelayernorm"; }

    // Derives the output shape from the first (and, per the has(1) check,
    // presumably only) input:
    //   - scalar input      -> the input shape is returned unchanged
    //   - broadcasted input -> a new shape built from the type and lens alone
    //   - anything else     -> the input shape passed through with_lens() using
    //                          its own lens
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1);
        auto input = inputs.at(0);

        if(input.scalar())
            return input;

        if(input.broadcasted())
            return shape{input.type(), input.lens()};

        return input.with_lens(input.lens());
    }
};
Paul's avatar
Paul committed
32
MIGRAPHX_REGISTER_OP(layernorm);
Paul's avatar
Paul committed
33

34
35
36
37
struct find_layernorm
{
    auto matcher() const { return match::layernorm(); }

Paul's avatar
Paul committed
38
39
40
41
42
43
44
45
46
47
48
49
50
    void apply(module& m, const match::matcher_result& r) const
    {
        auto ins   = r.result;
        auto x_ins = r.instructions["x"];

        m.replace_instruction(ins, layernorm{}, x_ins);
    }
};

struct find_gpulayernorm
{
    auto matcher() const { return match::layernorm(); }

51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
    void apply(module& m, const match::matcher_result& r) const
    {
        auto ins   = r.result;
        auto x_ins = r.instructions["x"];

        if(not x_ins->get_shape().standard())
            x_ins = m.insert_instruction(ins, make_op("contiguous"), x_ins);

        auto relements = x_ins->get_shape().lens().back();

        if(relements > 1024 or (relements % 4 != 0 and relements > 256))
            return;

        auto a = m.insert_instruction(
            ins, make_op("hip::allocate", {{"shape", to_value(x_ins->get_shape())}}));
        m.replace_instruction(ins, make_op("gpu::layernorm"), x_ins, a);
    }
};

Paul's avatar
Paul committed
70
struct find_gputriaddlayernorm
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
{
    auto matcher() const
    {
        auto add1 =
            match::name("add")(match::none_of(match::is_constant()),
                               match::args(match::any().bind("z1"), match::any().bind("z2")));
        auto add2 = match::name("add")(match::either_arg(0, 1)(add1, match::any().bind("z3")));
        return match::layernorm()(match::var("x")(add2));
    }

    void apply(module& m, const match::matcher_result& r) const
    {
        auto ins   = r.result;
        auto x_ins = r.instructions["z1"];
        auto y_ins = r.instructions["z2"];
        auto z_ins = r.instructions["z3"];

        for(auto* pins : {&x_ins, &y_ins, &z_ins})
        {
            if(not(*pins)->get_shape().standard())
                *pins = m.insert_instruction(ins, make_op("contiguous"), *pins);
        }

        auto relements = x_ins->get_shape().lens().back();

        if(relements > 1024 or (relements % 4 != 0 and relements > 256))
            return;

        auto a = m.insert_instruction(
            ins, make_op("hip::allocate", {{"shape", to_value(x_ins->get_shape())}}));
        m.replace_instruction(ins, make_op("gpu::triadd_layernorm"), x_ins, y_ins, z_ins, a);
    }
};
} // namespace

// Runs the pre-fusion rewrites over the module. Only the find_layernorm
// rewrite is currently active.
void prefuse_ops::apply(module& m) const
{
    find_layernorm prelayernorm_finder{};
    match::find_matches(m, prelayernorm_finder);
    // Disabled alternative that targets the GPU kernels directly:
    // match::find_matches(m, find_gputriaddlayernorm{}, find_gpulayernorm{});
}

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx