gemm.cpp 1.61 KB
Newer Older
1
2
3
4
5
6
7
8
#include <migraphx/config.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/context.hpp>
#include <migraphx/cpu/context.hpp>
#include <migraphx/cpu/dnnl.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/quant_dot.hpp>
Paul's avatar
Paul committed
9

Paul's avatar
Paul committed
10
namespace migraphx {
Paul's avatar
Paul committed
11
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
12
13
namespace cpu {

14
struct dnnl_gemm : dnnl_extend_op<dnnl_gemm, dnnl::matmul, op::dot>
Paul's avatar
Paul committed
15
{
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
    std::vector<int> arg_map(int) const { return {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS}; }

    // Batch must be a single dimension.
    //
    // Runs the shape through base_adjust_shape and, when the result has more than
    // three dimensions, folds every dimension in front of the last two into one
    // batch dimension, producing a 3-d shape. Shapes with three or fewer
    // dimensions are returned as base_adjust_shape produced them.
    //
    // Throws (via MIGRAPHX_THROW) when the strides of the leading batch dimensions
    // are not in non-increasing order, since such a batch cannot be merged into a
    // single dimension. The int parameter is unused.
    shape adjust_shape(shape x, int) const
    {
        auto s     = base_adjust_shape(x);
        auto ndims = s.lens().size();
        if(ndims <= 3)
            return s;

        const auto& lens    = s.lens();
        const auto& strides = s.strides();
        const auto nbatch   = ndims - 2;

        if(not std::is_sorted(strides.begin(), strides.begin() + nbatch, std::greater<>{}))
            MIGRAPHX_THROW("Batch transposed");

        // Seed with std::size_t so the product is accumulated at full width; an
        // int seed would narrow every intermediate product back to int.
        std::size_t batch = std::accumulate(
            lens.begin(), lens.begin() + nbatch, std::size_t{1}, std::multiplies<>{});

        // NOTE(review): the batch stride is taken as the product of the last two
        // lengths, which assumes those two dimensions are densely packed within
        // each batch entry -- confirm for padded/broadcast inputs.
        shape s3d{s.type(),
                  {batch, lens[ndims - 2], lens[ndims - 1]},
                  {lens[ndims - 2] * lens[ndims - 1], strides[ndims - 2], strides[ndims - 1]}};
        return s3d;
    }
Paul's avatar
Paul committed
42

43
44
45
46
47
    // Builds the matmul descriptor from the memory descriptors stored in `m` under
    // DNNL_ARG_SRC, DNNL_ARG_WEIGHTS and DNNL_ARG_DST. A missing key makes
    // unordered_map::at throw std::out_of_range.
    dnnl::matmul::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
    {
        const auto& src_md     = m.at(DNNL_ARG_SRC);
        const auto& weights_md = m.at(DNNL_ARG_WEIGHTS);
        const auto& dst_md     = m.at(DNNL_ARG_DST);
        return {src_md, weights_md, dst_md};
    }
};
48

Paul's avatar
Paul committed
49
} // namespace cpu
Paul's avatar
Paul committed
50
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
51
} // namespace migraphx