#include #include #include #include #include #include #include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace cpu { struct dnnl_gemm : dnnl_extend_op { std::vector arg_map(int) const { return {MIGRAPHX_DNNL_PREFIX(ARG_SRC), MIGRAPHX_DNNL_PREFIX(ARG_WEIGHTS), MIGRAPHX_DNNL_PREFIX(ARG_BIAS)}; } void required(const check_shapes& cs) const { cs.not_broadcasted(); } dnnl::matmul::desc get_desc(const std::unordered_map& m) const { return {m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC)), m.at(MIGRAPHX_DNNL_PREFIX(ARG_WEIGHTS)), m.at(MIGRAPHX_DNNL_PREFIX(ARG_DST))}; } }; } // namespace cpu } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx