#include #include #include #include #include #include #include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace cpu { struct dnnl_gemm : dnnl_extend_op { std::vector arg_map(int) const { return {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS}; } void required(const check_shapes& cs) const { cs.not_broadcasted(); } dnnl::matmul::desc get_desc(const std::unordered_map& m) const { return {m.at(DNNL_ARG_SRC), m.at(DNNL_ARG_WEIGHTS), m.at(DNNL_ARG_DST)}; } }; } // namespace cpu } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx