fuse_ck.cpp 3.54 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
#include <migraphx/gpu/fuse_ck.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct module;

namespace gpu {

struct ck_gemm
{
    operation op = make_op("dot");

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.op, "op"));
    }

    std::string name() const { return "gpu::ck_gemm"; }
Paul's avatar
Paul committed
25
26
27

    void check_gemm_shape(const shape& s) const
    {
Paul's avatar
Format  
Paul committed
28
        if(contains(s.lens(), 1))
Paul's avatar
Paul committed
29
30
31
            MIGRAPHX_THROW("Invalid shape for ck_gemm");
    }

Paul's avatar
Paul committed
32
33
    shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
    {
Paul's avatar
Paul committed
34
        check_shapes{inputs, *this}.same_ndims();
Paul's avatar
Paul committed
35
36
37
38
        // if(mods.size() != 1)
        //     MIGRAPHX_THROW("should have one submodule.");
        if(inputs.size() < 2)
            MIGRAPHX_THROW("should have at least two inputs.");
Paul's avatar
Paul committed
39
40
41
42
        auto a = inputs[0];
        auto b = inputs[1];
        for(const auto& input:inputs)
            check_gemm_shape(input);
Paul's avatar
Paul committed
43
44
45
46
47
48
49
50
51
52
53
54
55
        return op.compute_shape({a, b});
    }
};
MIGRAPHX_REGISTER_OP(ck_gemm);

namespace {

// Predicate matcher selecting the "dot" instructions this pass is willing to
// turn into a ck_gemm. An instruction is accepted only when:
//   - it is named "dot",
//   - both the first and the last input have exactly two dimensions,
//   - the second dimension of the first input is below 2048, and
//   - all four dimensions are multiples of 8.
// NOTE(review): the rationale for the 2048 limit and the multiple-of-8
// requirement is not visible in this file - confirm against the kernel side.
MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
{
    if(ins->name() != "dot")
        return false;
    auto a = ins->inputs().front()->get_shape();
    auto b = ins->inputs().back()->get_shape();
    const auto& a_lens = a.lens();
    const auto& b_lens = b.lens();
    // Require exactly two dimensions so the [0]/[1] accesses below stay in
    // bounds; anything of higher (or lower) rank is rejected.
    if(a_lens.size() != 2 or b_lens.size() != 2)
        return false;
    if(a_lens[1] >= 2048)
        return false;
    return (a_lens[0] % 8 == 0 and a_lens[1] % 8 == 0 and b_lens[0] % 8 == 0 and
            b_lens[1] % 8 == 0);
}

struct find_ck_gemm
{
Paul's avatar
Paul committed
66
67
68
69
70
    // Find a gemm followed by a pointwise operation.
    auto matcher() const {
        auto gemm = match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm"))); 
        return match::name("pointwise")(match::any_of[match::inputs()](gemm.bind("x")));
    }
Paul's avatar
Paul committed
71
72
73

    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
Paul's avatar
Paul committed
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
        auto ins      = r.result;
        auto gemm_ins = r.instructions["gemm"];
        auto x_ins    = r.instructions["x"]; // input after contiguous
        auto* pm      = ins->module_inputs().front();
        auto names    = pm->get_parameter_names();
        std::sort(names.begin(), names.end());
        auto inputs = ins->inputs();
        auto gemm_it = std::find(inputs.begin(), inputs.end(), x_ins);
        auto gemm_idx = gemm_it - inputs.begin();
        assert(gemm_it != inputs.end());
        if (gemm_idx != 0)
        {
            // std::swap(inputs[0], inputs[gemm_idx]);
            auto first_param = pm->get_parameter(names[0]);
            auto gemm_param = pm->get_parameter(names[gemm_idx]);
            auto new_gemm_param = pm->add_parameter(names[0] + ".0", gemm_param->get_shape());
            auto new_first_param = pm->add_parameter(names[gemm_idx] + ".0", first_param->get_shape());
            pm->replace_instruction(gemm_param, new_gemm_param);
            pm->replace_instruction(first_param, new_first_param);
            pm->remove_instruction(first_param);
            pm->remove_instruction(gemm_param);
        }
        inputs.erase(gemm_it);
        inputs.insert(inputs.begin(), gemm_ins->inputs().begin(), gemm_ins->inputs().end());

        mpm.get_module().replace_instruction(ins, ck_gemm{}, inputs, {pm});
Paul's avatar
Paul committed
100
101
102
103
104
    }
};

} // namespace

Paul's avatar
Format  
Paul committed
105
// Pass entry point: runs the find_ck_gemm matcher over the module managed by
// `mpm`, rewriting every match it finds.
void fuse_ck::apply(module_pass_manager& mpm) const
{
    match::find_matches(mpm, find_ck_gemm{});
}
Paul's avatar
Paul committed
106
107
108
109
110

} // namespace gpu

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx