#include <migraphx/gpu/fuse_ck.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct module;

namespace gpu {

struct ck_gemm
{
    operation op = make_op("dot");

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.op, "op"));
    }

    std::string name() const { return "gpu::ck_gemm"; }
Paul's avatar
Paul committed
25
26
27

    void check_gemm_shape(const shape& s) const
    {
Paul's avatar
Format  
Paul committed
28
        if(contains(s.lens(), 1))
Paul's avatar
Paul committed
29
30
31
            MIGRAPHX_THROW("Invalid shape for ck_gemm");
    }

Paul's avatar
Paul committed
32
33
    shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
    {
Paul's avatar
Paul committed
34
        check_shapes{inputs, *this}.not_broadcasted();
Paul's avatar
Paul committed
35
36
37
38
39
40
41
        // if(mods.size() != 1)
        //     MIGRAPHX_THROW("should have one submodule.");
        if(inputs.size() < 2)
            MIGRAPHX_THROW("should have at least two inputs.");
        auto n = inputs.size();
        auto a = inputs[n - 2];
        auto b = inputs[n - 1];
Paul's avatar
Paul committed
42
43
        check_gemm_shape(a);
        check_gemm_shape(b);
Paul's avatar
Paul committed
44
45
46
47
48
        return op.compute_shape({a, b});
    }
};
MIGRAPHX_REGISTER_OP(ck_gemm);

/// Operator `gpu::ck_batched_gemm`, used by this pass as the replacement for a
/// batched (rank >= 3) `dot`.
///
/// The wrapped `op` (a `dot` unless reflected otherwise) computes the output
/// shape; this operator only adds the input checks in `compute_shape`.
struct ck_batched_gemm
{
    operation op = make_op("dot");

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.op, "op"));
    }

    std::string name() const { return "gpu::ck_batched_gemm"; }

    /// Throws if any dimension of `s` (batch dimensions included) has length 1.
    void check_gemm_shape(const shape& s) const
    {
        if(contains(s.lens(), 1))
            MIGRAPHX_THROW("Invalid shape for ck_batched_gemm");
    }

    /// Computes the output shape from the last two inputs (the GEMM operands A, B).
    ///
    /// @param inputs shapes of all inputs; none may be broadcasted (checked via
    ///               check_shapes::not_broadcasted) and there must be at least two.
    /// @param mods   attached submodules; not inspected yet.
    /// @return the shape `op` computes for {A, B}.
    /// @throws when fewer than two inputs are given, an input is broadcasted,
    ///         or A or B has a dimension of length 1.
    shape compute_shape(std::vector<shape> inputs,
                        [[maybe_unused]] const std::vector<module_ref>& mods) const
    {
        check_shapes{inputs, *this}.not_broadcasted();
        // TODO: an earlier draft required exactly one submodule
        // (`mods.size() == 1`); restore that check once one is attached.
        if(inputs.size() < 2)
            MIGRAPHX_THROW("should have at least two inputs.");
        // Operands are the trailing two inputs; bind by reference to avoid copies.
        const auto& a = inputs[inputs.size() - 2];
        const auto& b = inputs.back();
        check_gemm_shape(a);
        check_gemm_shape(b);
        return op.compute_shape({a, b});
    }
};
MIGRAPHX_REGISTER_OP(ck_batched_gemm);

namespace {

// Accepts a 2-D `dot` that gpu::ck_gemm can take over.
//
// Every shape that ck_gemm::compute_shape would throw on (broadcasted inputs,
// any dimension of length 1) is rejected here. find_ck_gemm rebuilds the
// instruction as a ck_gemm, and (assuming replace_instruction computes the new
// shape, which it needs to type the result) accepting such a dot would abort
// the pass instead of leaving the dot untouched.
MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
{
    if(ins->name() != "dot")
        return false;
    const auto& a = ins->inputs().front()->get_shape();
    const auto& b = ins->inputs().back()->get_shape();
    // Exactly rank 2: higher ranks are handled by is_ck_batched_gemm, and the
    // rank check also guarantees lens().back() below is valid.
    if(a.lens().size() != 2 or b.lens().size() != 2)
        return false;
    if(a.broadcasted() or b.broadcasted())
        return false;
    if(contains(a.lens(), 1) or contains(b.lens(), 1))
        return false;
    // Cap on A's last (reduction) dimension; the reason for 1024 is not stated here.
    if(a.lens().back() > 1024)
        return false;
    return true;
}

// Accepts a batched (rank >= 3) `dot` that gpu::ck_batched_gemm can take over.
//
// Every shape that ck_batched_gemm::compute_shape would throw on (broadcasted
// inputs, any dimension of length 1, batch dimensions included) is rejected
// here. find_ck_batched_gemm rebuilds the instruction as a ck_batched_gemm, and
// (assuming replace_instruction computes the new shape) accepting such a dot
// would abort the pass instead of leaving the dot untouched.
MIGRAPHX_PRED_MATCHER(is_ck_batched_gemm, instruction_ref ins)
{
    if(ins->name() != "dot")
        return false;
    const auto& a = ins->inputs().front()->get_shape();
    const auto& b = ins->inputs().back()->get_shape();
    if(a.lens().size() < 3 or b.lens().size() < 3)
        return false;
    if(a.broadcasted() or b.broadcasted())
        return false;
    if(contains(a.lens(), 1) or contains(b.lens(), 1))
        return false;
    // Cap on A's last (reduction) dimension; the reason for 1024 is not stated here.
    if(a.lens().back() > 1024)
        return false;
    return true;
}

struct find_ck_gemm
{
Alan Turner's avatar
Alan Turner committed
116
    // Find a gemm that can be replaced with a ck_gemm
Paul's avatar
Format  
Paul committed
117
    auto matcher() const { return match::name("dot")(is_ck_gemm().bind("gemm")); }
Paul's avatar
Paul committed
118
119
120

    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
Paul's avatar
Format  
Paul committed
121
        auto ins = r.result;
Paul's avatar
Paul committed
122
123
124
125
        mpm.get_module().replace_instruction(ins, ck_gemm{ins->get_operator()}, ins->inputs());
    }
};

// Match driver: swaps a batched `dot` accepted by is_ck_batched_gemm for a
// gpu::ck_batched_gemm that wraps the dot's operator and reuses its inputs.
struct find_ck_batched_gemm
{
    auto matcher() const
    {
        // Name filter plus the shape predicate, which is bound as "gemm".
        return match::name("dot")(is_ck_batched_gemm().bind("gemm"));
    }

    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        const auto gemm_ins = r.result;
        auto& m             = mpm.get_module();
        const auto& args    = gemm_ins->inputs();
        m.replace_instruction(gemm_ins, ck_batched_gemm{gemm_ins->get_operator()}, args);
    }
};

} // namespace

void fuse_ck::apply(module_pass_manager& mpm) const
{
    // The finders are disjoint on operand rank (2-D vs. rank >= 3), and a
    // rewrite leaves the instruction's shape unchanged. One traversal with
    // both finders therefore rewrites exactly the instructions two passes would.
    match::find_matches(mpm, find_ck_gemm{}, find_ck_batched_gemm{});
}

} // namespace gpu

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx