#include <algorithm>
#include <cstdint>
#include <functional>
#include <numeric>
#include <string>
#include <unordered_map>
#include <vector>

#include <migraphx/config.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/context.hpp>
#include <migraphx/cpu/context.hpp>
#include <migraphx/cpu/dnnl.hpp>
#include <migraphx/cpu/migemm.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/quant_dot.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace cpu {
#if USE_DNNL
struct dnnl_gemm : dnnl_extend_op<dnnl_gemm, dnnl::matmul, op::dot>
Paul's avatar
Paul committed
17
{
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
    std::vector<int> arg_map(int) const { return {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS}; }

    // Batch must be a single dimension
    shape adjust_shape(shape x, int) const
    {
        auto s     = base_adjust_shape(x);
        auto ndims = s.lens().size();
        if(ndims > 3)
        {
            if(not std::is_sorted(
                   s.strides().begin(), s.strides().begin() + (ndims - 2), std::greater<>{}))
                MIGRAPHX_THROW("Batch transposed");
            std::size_t batch = std::accumulate(
                s.lens().begin(), s.lens().begin() + (ndims - 2), 1, std::multiplies<>{});
            shape s3d{s.type(),
                      {batch, s.lens()[ndims - 2], s.lens()[ndims - 1]},
                      {s.lens()[ndims - 2] * s.lens()[ndims - 1],
                       s.strides()[ndims - 2],
                       s.strides()[ndims - 1]}};
            return s3d;
        }
        else
        {
            return s;
        }
    }
Paul's avatar
Paul committed
44

45
46
47
48
49
50
    dnnl::matmul::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
    {
        return {m.at(DNNL_ARG_SRC), m.at(DNNL_ARG_WEIGHTS), m.at(DNNL_ARG_DST)};
    }
};
#endif
struct cpu_gemm : auto_register_op<cpu_gemm>
Paul's avatar
Paul committed
53
{
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
    op::dot op;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }
    std::string name() const { return "cpu::dot"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.standard();
        inputs.pop_back();
        return op.compute_shape(inputs);
    }

    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
    {
        return shapes.size() - 1;
    }

    argument compute(context&, const shape&, std::vector<argument> args) const
    {
        // 3 inputs, it is alpha * A * B + beta * C, then
        // A and B are matrices, and C is of the same shape as A * B
        if(args.size() == 4)
        {
            // no need to consider the value of args[2]
            if(op.beta == 0.0f)
Shucai Xiao's avatar
Shucai Xiao committed
82
            {
83
84
85
86
87
88
89
                args.back().visit([&](auto output) { std::fill(output.begin(), output.end(), 0); });
            }
            else
            {
                visit_all(args.back(), args[2])([&](auto output, auto input) {
                    std::copy(input.begin(), input.end(), output.begin());
                });
Shucai Xiao's avatar
Shucai Xiao committed
90
            }
Paul's avatar
Paul committed
91

92
93
94
95
96
97
98
            migemm(args.back(), args[0], args[1], op.alpha, op.beta);

            return args.back();
        }

        // 2 input arguments
        migemm(args.back(), args[0], args[1], op.alpha, 0.0f);
Paul's avatar
Paul committed
99

100
101
102
103
104
        return args.back();
    }
};

struct cpu_quant_gemm : auto_register_op<cpu_quant_gemm>
Paul's avatar
Paul committed
105
{
106
107
108
109
    op::quant_dot op;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
110
    {
111
        return migraphx::reflect(self.op, f);
112
    }
113
114
115

    std::string name() const { return "cpu::quant_dot"; }
    shape compute_shape(std::vector<shape> inputs) const
116
    {
117
118
119
        check_shapes{inputs, *this}.standard();
        inputs.pop_back();
        return op.compute_shape(inputs);
120
    }
Paul's avatar
Paul committed
121

122
123
124
125
    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
    {
        return shapes.size() - 1;
    }
Paul's avatar
Paul committed
126

127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
    argument compute(context&, const shape&, std::vector<argument> args) const
    {
        // 3 inputs, it is alpha * A * B + beta * C, then
        // A and B are matrices, and C is of the same shape to A * B

        // first, convert the args[0] and args[1] from int8_t to int32_t
        argument arg_0{{shape::int32_type, {args.at(0).get_shape().lens()}}};
        argument arg_1{{shape::int32_type, {args.at(1).get_shape().lens()}}};
        arg_0.visit([&](auto output) {
            args.at(0).visit(
                [&](auto input) { std::copy(input.begin(), input.end(), output.begin()); });
        });

        arg_1.visit([&](auto output) {
            args.at(1).visit(
                [&](auto input) { std::copy(input.begin(), input.end(), output.begin()); });
        });

        if(args.size() == 4)
        {
            // no need to consider the value of args[2]
            if(op.beta == 0)
            {
                args.back().visit([&](auto output) { std::fill(output.begin(), output.end(), 0); });
            }
            else
            {
                visit_all(args.back(), args[2])([&](auto output, auto input) {
                    std::copy(input.begin(), input.end(), output.begin());
                });
            }

            migemm(args.back(), arg_0, arg_1, op.alpha, op.beta);

            return args.back();
        }

        // 2 input arguments
        migemm(args.back(), arg_0, arg_1, op.alpha, int32_t{0});

        return args.back();
    }
};
} // namespace cpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx