gemm.cpp 842 Bytes
Newer Older
1
2
3
4
5
6
7
8
#include <migraphx/config.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/context.hpp>
#include <migraphx/cpu/context.hpp>
#include <migraphx/cpu/dnnl.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/quant_dot.hpp>
Paul's avatar
Paul committed
9

Paul's avatar
Paul committed
10
namespace migraphx {
Paul's avatar
Paul committed
11
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
12
13
namespace cpu {

14
// DNNL-backed operator for op::dot, expressed through dnnl::matmul.
// The struct passes itself as the first template argument of dnnl_extend_op;
// how that base consumes the member functions below is defined outside this file.
struct dnnl_gemm : dnnl_extend_op<dnnl_gemm, dnnl::matmul, op::dot>
{
    // Returns the DNNL argument ids {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS}, in that
    // order. The integer parameter is ignored.
    // NOTE(review): presumably entry i is the DNNL slot for the i-th input --
    // confirm against dnnl_extend_op.
    std::vector<int> arg_map(int) const
    {
        std::vector<int> ids{DNNL_ARG_SRC, DNNL_ARG_WEIGHTS};
        return ids;
    }

    // Shape precondition for this operator: runs the not_broadcasted() check
    // on the supplied check_shapes object.
    void required(const check_shapes& cs) const
    {
        cs.not_broadcasted();
    }

    // Assembles the matmul descriptor from the source, weights and destination
    // memory descriptors stored in `m` under their DNNL argument ids.
    // Lookups use at(), so a missing key throws std::out_of_range.
    dnnl::matmul::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
    {
        const auto& src_md     = m.at(DNNL_ARG_SRC);
        const auto& weights_md = m.at(DNNL_ARG_WEIGHTS);
        const auto& dst_md     = m.at(DNNL_ARG_DST);
        return {src_md, weights_md, dst_md};
    }
};
25

Paul's avatar
Paul committed
26
} // namespace cpu
Paul's avatar
Paul committed
27
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
28
} // namespace migraphx