#ifndef GEMM_UTILS_HPP #define GEMM_UTILS_HPP #include "config.hpp" #include "device.hpp" #include "host_tensor.hpp" namespace ck { namespace gemm_util { struct GemmParams { GemmParams() : M(1024), N(1024), K(1024), StrideA(1024), StrideB(1024), StrideC(1024), alpha(1), beta(0) { } ck::index_t M; ck::index_t N; ck::index_t K; ck::index_t StrideA; ck::index_t StrideB; ck::index_t StrideC; float alpha; float beta; }; template void RunHostGEMM(const Tensor& A, const Tensor& B, Tensor& C, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) { auto ref_gemm = GemmInstance{}; auto ref_invoker = ref_gemm.MakeInvoker(); auto ref_argument = ref_gemm.MakeArgument(A, B, C, a_element_op, b_element_op, c_element_op); ref_invoker.Run(ref_argument); } template void RunDeviceGEMM(DeviceGemmPtr_& gemmPtr, const ck::gemm_util::GemmParams& params, const Tensor& A, const Tensor& B, Tensor& C, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) { DeviceMem a_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpace()); DeviceMem b_k_n_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpace()); DeviceMem c_m_n_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpace()); a_m_k_device_buf.ToDevice(A.mData.data()); b_k_n_device_buf.ToDevice(B.mData.data()); auto invoker_ptr = gemmPtr->MakeInvokerPointer(); auto argument_ptr = gemmPtr->MakeArgumentPointer(static_cast(a_m_k_device_buf.GetDeviceBuffer()), static_cast(b_k_n_device_buf.GetDeviceBuffer()), static_cast(c_m_n_device_buf.GetDeviceBuffer()), params.M, params.N, params.K, params.StrideA, params.StrideB, params.StrideC, a_element_op, b_element_op, c_element_op); if(!gemmPtr->IsSupportedArgument(argument_ptr.get())) { throw std::runtime_error( "wrong! device_gemm with the specified compilation parameters does " "not support this GEMM problem"); } invoker_ptr->Run(argument_ptr.get()); c_m_n_device_buf.FromDevice(C.mData.data()); } } // namespace gemm_util } // namespace ck #endif