#include #include #include #include #include #include #include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace cpu { struct dnnl_gemm : dnnl_extend_op { std::vector arg_map(int) const { return {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS}; } // Batch must be a single dimension shape adjust_shape(shape x, int) const { auto s = base_adjust_shape(x); auto ndims = s.lens().size(); if(ndims > 3) { if(not std::is_sorted( s.strides().begin(), s.strides().begin() + (ndims - 2), std::greater<>{})) MIGRAPHX_THROW("Batch transposed"); std::size_t batch = std::accumulate( s.lens().begin(), s.lens().begin() + (ndims - 2), 1, std::multiplies<>{}); shape s3d{s.type(), {batch, s.lens()[ndims - 2], s.lens()[ndims - 1]}, {s.lens()[ndims - 2] * s.lens()[ndims - 1], s.strides()[ndims - 2], s.strides()[ndims - 1]}}; return s3d; } else { return s; } } dnnl::matmul::desc get_desc(const std::unordered_map& m) const { return {m.at(DNNL_ARG_SRC), m.at(DNNL_ARG_WEIGHTS), m.at(DNNL_ARG_DST)}; } }; } // namespace cpu } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx