// gemm.cpp
#include <migraphx/config.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/context.hpp>
#include <migraphx/cpu/context.hpp>
#include <migraphx/cpu/dnnl.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/quant_dot.hpp>
namespace migraphx {
Paul's avatar
Paul committed
11
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
12
13
namespace cpu {

14
struct dnnl_gemm : dnnl_extend_op<dnnl_gemm, dnnl::matmul, op::dot>
Paul's avatar
Paul committed
15
{
16
17
18
19
20
21
    std::vector<int> arg_map(int) const
    {
        return {MIGRAPHX_DNNL_PREFIX(ARG_SRC),
                MIGRAPHX_DNNL_PREFIX(ARG_WEIGHTS),
                MIGRAPHX_DNNL_PREFIX(ARG_BIAS)};
    }
22

23
    void required(const check_shapes& cs) const { cs.not_broadcasted(); }
Paul's avatar
Paul committed
24

25
26
    dnnl::matmul::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
    {
27
28
29
        return {m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC)),
                m.at(MIGRAPHX_DNNL_PREFIX(ARG_WEIGHTS)),
                m.at(MIGRAPHX_DNNL_PREFIX(ARG_DST))};
30
31
    }
};
32

Paul's avatar
Paul committed
33
} // namespace cpu
Paul's avatar
Paul committed
34
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
35
} // namespace migraphx