fuse_ck.cpp 7.67 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
#include <migraphx/gpu/fuse_ck.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
Paul's avatar
Paul committed
6
#include <migraphx/env.hpp>
Paul's avatar
Paul committed
7
8
9
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

Paul's avatar
Paul committed
10
11
12
// Environment switches consulted by fuse_ck::apply:
//  - MIGRAPHX_DISABLE_CK_GEMM: when enabled, the find_ck_gemm / find_ck_gemm_int8
//    rewrites are skipped.
//  - MIGRAPHX_DISABLE_CK_GEMM_FUSION: when enabled, the find_ck_gemm_pointwise /
//    find_ck_gemm_pointwise_int8 rewrites are skipped.
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_CK_GEMM);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_CK_GEMM_FUSION);

Paul's avatar
Paul committed
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
struct module;

namespace gpu {

struct ck_gemm
{
    operation op = make_op("dot");

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.op, "op"));
    }

    std::string name() const { return "gpu::ck_gemm"; }
Paul's avatar
Paul committed
28
29
30

    void check_gemm_shape(const shape& s) const
    {
Paul's avatar
Format  
Paul committed
31
        if(not contains(range(s.strides().rbegin(), s.strides().rbegin() + 3), 1))
Paul's avatar
Paul committed
32
33
34
            MIGRAPHX_THROW("Invalid shape for ck_gemm");
    }

Paul's avatar
Paul committed
35
36
    shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
    {
Paul's avatar
Paul committed
37
        check_shapes{inputs, *this}.same_ndims();
Paul's avatar
Paul committed
38
39
40
41
        // if(mods.size() != 1)
        //     MIGRAPHX_THROW("should have one submodule.");
        if(inputs.size() < 2)
            MIGRAPHX_THROW("should have at least two inputs.");
Paul's avatar
Paul committed
42
43
        auto a = inputs[0];
        auto b = inputs[1];
Paul's avatar
Format  
Paul committed
44
        for(const auto& input : inputs)
Paul's avatar
Paul committed
45
            check_gemm_shape(input);
Paul's avatar
Paul committed
46
        auto r = op.compute_shape({a, b});
Paul's avatar
Format  
Paul committed
47
        if(mods.empty())
Paul's avatar
Paul committed
48
49
            return r;
        return r.with_type(mods.front()->get_output_shapes().front().type());
Paul's avatar
Paul committed
50
51
52
53
    }
};
MIGRAPHX_REGISTER_OP(ck_gemm);

54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
// Operator registered as "gpu::ck_gemm_int8". Its output shape is derived from the
// wrapped `op` (a "quant_dot" by default). The output type is int8 when no submodule
// is attached, otherwise it is taken from the submodule's first output shape.
struct ck_gemm_int8
{
    // Wrapped operation used to compute the GEMM output shape.
    operation op = make_op("quant_dot");

    // Exposes `op` to reflection under the field name "op".
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.op, "op"));
    }

    std::string name() const { return "gpu::ck_gemm_int8"; }

    // Throws unless at least one of the trailing strides equals 1. The window is the
    // last three strides, or all of them when the shape has fewer than three
    // dimensions.
    void check_gemm_shape(const shape& s) const
    {
        const auto& strides = s.strides();
        // Clamp the window to the rank: advancing rbegin() by 3 on a shorter stride
        // vector would step past rend(), which is undefined behaviour.
        const auto last = strides.size() < 3 ? strides.rend() : strides.rbegin() + 3;
        if(not contains(range(strides.rbegin(), last), 1))
            MIGRAPHX_THROW("Invalid shape for ck_gemm");
    }

    // Computes the output shape.
    //  inputs: at least two shapes; the first two are handed to `op`, and every input
    //          must pass check_gemm_shape.
    //  mods:   optional submodules; when empty the result is forced to int8, otherwise
    //          the result type is the type of the first output shape of the first
    //          submodule.
    // Throws when fewer than two inputs are given or when any input fails
    // check_gemm_shape.
    shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
    {
        check_shapes{inputs, *this}.same_ndims();
        // The submodule is optional, so the count is deliberately not enforced:
        // if(mods.size() != 1)
        //     MIGRAPHX_THROW("should have one submodule.");
        if(inputs.size() < 2)
            MIGRAPHX_THROW("should have at least two inputs.");
        for(const auto& input : inputs)
            check_gemm_shape(input);
        auto r = op.compute_shape({inputs[0], inputs[1]});
        if(mods.empty())
            return r.with_type(migraphx::shape::int8_type);
        return r.with_type(mods.front()->get_output_shapes().front().type());
    }
};
MIGRAPHX_REGISTER_OP(ck_gemm_int8);

Paul's avatar
Paul committed
91
92
93
94
namespace {

// Predicate matcher: true for a "dot" or "quant_dot" instruction whose first input has
// a last dimension of at most 2048.
// NOTE(review): the 2048 cap is presumably a limit of the CK kernels - confirm.
MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
{
    const auto op_name = ins->name();
    if(op_name != "dot" and op_name != "quant_dot")
        return false;
    // Only the first operand is inspected; the second operand's shape is not needed.
    const auto& a = ins->inputs().front()->get_shape();
    return a.lens().back() <= 2048;
}

Paul's avatar
Paul committed
104
struct find_ck_gemm_pointwise
Paul's avatar
Paul committed
105
{
Paul's avatar
Paul committed
106
    // Find a gemm followed by a pointwise operation.
Paul's avatar
Format  
Paul committed
107
108
109
110
    auto matcher() const
    {
        auto gemm =
            match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm")));
Paul's avatar
Paul committed
111
112
        return match::name("pointwise")(match::any_of[match::inputs()](gemm.bind("x")));
    }
Paul's avatar
Paul committed
113
114
115

    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
Paul's avatar
Paul committed
116
117
118
119
120
121
        auto ins      = r.result;
        auto gemm_ins = r.instructions["gemm"];
        auto x_ins    = r.instructions["x"]; // input after contiguous
        auto* pm      = ins->module_inputs().front();
        auto names    = pm->get_parameter_names();
        std::sort(names.begin(), names.end());
Paul's avatar
Format  
Paul committed
122
123
        auto inputs   = ins->inputs();
        auto gemm_it  = std::find(inputs.begin(), inputs.end(), x_ins);
Paul's avatar
Paul committed
124
125
        auto gemm_idx = gemm_it - inputs.begin();
        assert(gemm_it != inputs.end());
Paul's avatar
Format  
Paul committed
126
        if(ins->get_shape().type() != shape::half_type)
Paul's avatar
Paul committed
127
            return;
Paul's avatar
Format  
Paul committed
128
        if(gemm_idx != 0)
Paul's avatar
Paul committed
129
        {
Paul's avatar
Format  
Paul committed
130
131
            auto first_param    = pm->get_parameter(names[0]);
            auto gemm_param     = pm->get_parameter(names[gemm_idx]);
132
            auto new_gemm_param = pm->add_parameter(names[0] + "_0", gemm_param->get_shape());
Paul's avatar
Format  
Paul committed
133
            auto new_first_param =
134
                pm->add_parameter(names[gemm_idx] + "_0", first_param->get_shape());
Paul's avatar
Paul committed
135
136
137
138
139
140
141
142
143
            pm->replace_instruction(gemm_param, new_gemm_param);
            pm->replace_instruction(first_param, new_first_param);
            pm->remove_instruction(first_param);
            pm->remove_instruction(gemm_param);
        }
        inputs.erase(gemm_it);
        inputs.insert(inputs.begin(), gemm_ins->inputs().begin(), gemm_ins->inputs().end());

        mpm.get_module().replace_instruction(ins, ck_gemm{}, inputs, {pm});
Paul's avatar
Paul committed
144
145
146
    }
};

147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
// Rewrites a pointwise instruction that consumes a CK-eligible "quant_dot" into a
// single ck_gemm_int8 instruction carrying the pointwise submodule.
struct find_ck_gemm_pointwise_int8
{
    // Find a gemm followed by a pointwise operation: the "quant_dot" accepted by
    // is_ck_gemm is bound as "gemm", and the pointwise input that leads to it (with
    // "contiguous" instructions skipped over) is bound as "x".
    auto matcher() const
    {
        auto gemm =
            match::skip(match::name("contiguous"))(match::name("quant_dot")(is_ck_gemm().bind("gemm")));
        return match::name("pointwise")(match::any_of[match::inputs()](gemm.bind("x")));
    }

    // Performs the rewrite on the matched pointwise instruction.
    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto ins      = r.result;
        auto gemm_ins = r.instructions["gemm"];
        auto x_ins    = r.instructions["x"]; // input after contiguous
        auto* pm      = ins->module_inputs().front();
        auto names    = pm->get_parameter_names();
        std::sort(names.begin(), names.end());
        auto inputs   = ins->inputs();
        auto gemm_it  = std::find(inputs.begin(), inputs.end(), x_ins);
        auto gemm_idx = gemm_it - inputs.begin();
        assert(gemm_it != inputs.end());
        if(gemm_idx != 0)
        {
            // Exchange the shapes of the first parameter and the gemm parameter inside
            // the submodule: add two "_0"-suffixed parameters, redirect the old ones to
            // them, then drop the old ones.
            auto first_param    = pm->get_parameter(names[0]);
            auto gemm_param     = pm->get_parameter(names[gemm_idx]);
            auto new_gemm_param = pm->add_parameter(names[0] + "_0", gemm_param->get_shape());
            auto new_first_param =
                pm->add_parameter(names[gemm_idx] + "_0", first_param->get_shape());
            pm->replace_instruction(gemm_param, new_gemm_param);
            pm->replace_instruction(first_param, new_first_param);
            pm->remove_instruction(first_param);
            pm->remove_instruction(gemm_param);
        }
        // Substitute the gemm result in the input list with the gemm's own inputs,
        // placed at the front.
        inputs.erase(gemm_it);
        inputs.insert(inputs.begin(), gemm_ins->inputs().begin(), gemm_ins->inputs().end());
        mpm.get_module().replace_instruction(ins, ck_gemm_int8{}, inputs, {pm});
    }
};

Paul's avatar
Paul committed
189
190
191
192
193
194
195
196
197
198
199
// Replaces a standalone CK-eligible "dot" with a ck_gemm wrapping the same operator.
struct find_ck_gemm
{
    // Matches "dot" instructions accepted by is_ck_gemm (bound as "gemm").
    auto matcher() const
    {
        auto eligible = is_ck_gemm().bind("gemm");
        return match::name("dot")(eligible);
    }

    // Swaps the matched instruction for a ck_gemm that keeps the original operator and
    // the original inputs.
    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto dot_ins = r.result;
        auto& target = mpm.get_module();
        target.replace_instruction(dot_ins, ck_gemm{dot_ins->get_operator()}, dot_ins->inputs());
    }
};

200
201
202
203
204
205
206
207
208
209
210
// Replaces a standalone CK-eligible "quant_dot" with a ck_gemm_int8 wrapping the same
// operator.
struct find_ck_gemm_int8
{
    // Matches "quant_dot" instructions accepted by is_ck_gemm (bound as "gemm").
    auto matcher() const
    {
        auto eligible = is_ck_gemm().bind("gemm");
        return match::name("quant_dot")(eligible);
    }

    // Swaps the matched instruction for a ck_gemm_int8 that keeps the original operator
    // and the original inputs.
    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto qdot_ins = r.result;
        auto& target  = mpm.get_module();
        target.replace_instruction(
            qdot_ins, ck_gemm_int8{qdot_ins->get_operator()}, qdot_ins->inputs());
    }
};

Paul's avatar
Paul committed
211
212
} // namespace

Paul's avatar
Format  
Paul committed
213
214
// Runs the CK rewrites over the module held by `mpm`. The gemm+pointwise fusions run
// first, then the standalone gemm replacements; each group is skipped when its
// MIGRAPHX_DISABLE_* environment switch is enabled.
void fuse_ck::apply(module_pass_manager& mpm) const
{
    const bool fusion_disabled = enabled(MIGRAPHX_DISABLE_CK_GEMM_FUSION{});
    if(not fusion_disabled)
    {
        match::find_matches(mpm, find_ck_gemm_pointwise{});
        match::find_matches(mpm, find_ck_gemm_pointwise_int8{});
    }

    const bool gemm_disabled = enabled(MIGRAPHX_DISABLE_CK_GEMM{});
    if(not gemm_disabled)
    {
        match::find_matches(mpm, find_ck_gemm{});
        match::find_matches(mpm, find_ck_gemm_int8{});
    }
}
Paul's avatar
Paul committed
226
227
228
229
230

} // namespace gpu

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx