#ifndef CGEMM_UTILS_HPP
#define CGEMM_UTILS_HPP

#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "reference_cgemm.hpp"
#include "tensor_layout.hpp"

namespace ck {
namespace cgemm_util {

// Problem-size and stride parameters shared by all CGEMM tests.
struct CGemmParams
{
    CGemmParams()
        : M(1024), N(1024), K(1024), StrideA(1024), StrideB(1024), StrideC(1024), alpha(1), beta(0)
    {
    }

    ck::index_t M;
    ck::index_t N;
    ck::index_t K;

    ck::index_t StrideA;
    ck::index_t StrideB;
    ck::index_t StrideC;

    float alpha;
    float beta;
};

// Run the host reference CGEMM to produce the golden results.
template <typename CGemmInstance,
          typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
void RunHostCGEMM(const Tensor<ADataType>& A_real,
                  const Tensor<ADataType>& A_imag,
                  const Tensor<BDataType>& B_real,
                  const Tensor<BDataType>& B_imag,
                  Tensor<CDataType>& C_real,
                  Tensor<CDataType>& C_imag,
                  AElementwiseOperation a_element_op,
                  BElementwiseOperation b_element_op,
                  CElementwiseOperation c_element_op)
{
    auto ref_cgemm    = CGemmInstance{};
    auto ref_invoker  = ref_cgemm.MakeInvoker();
    auto ref_argument = ref_cgemm.MakeArgument(
        A_real, A_imag, B_real, B_imag, C_real, C_imag, a_element_op, b_element_op, c_element_op);

    ref_invoker.Run(ref_argument);
}

// Copy the inputs to the device, run the device CGEMM instance, and copy the
// results back into C_real / C_imag.
template <typename DeviceCGemmPtr_,
          typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
void RunDeviceCGEMM(DeviceCGemmPtr_& cgemmPtr,
                    const ck::cgemm_util::CGemmParams& params,
                    const Tensor<ADataType>& A_real,
                    const Tensor<ADataType>& A_imag,
                    const Tensor<BDataType>& B_real,
                    const Tensor<BDataType>& B_imag,
                    Tensor<CDataType>& C_real,
                    Tensor<CDataType>& C_imag,
                    AElementwiseOperation a_element_op,
                    BElementwiseOperation b_element_op,
                    CElementwiseOperation c_element_op)
{
    DeviceMem a_m_k_real_device_buf(sizeof(ADataType) * A_real.mDesc.GetElementSpace());
    DeviceMem a_m_k_imag_device_buf(sizeof(ADataType) * A_imag.mDesc.GetElementSpace());
    DeviceMem b_k_n_real_device_buf(sizeof(BDataType) * B_real.mDesc.GetElementSpace());
    DeviceMem b_k_n_imag_device_buf(sizeof(BDataType) * B_imag.mDesc.GetElementSpace());
    DeviceMem c_m_n_real_device_buf(sizeof(CDataType) * C_real.mDesc.GetElementSpace());
    DeviceMem c_m_n_imag_device_buf(sizeof(CDataType) * C_imag.mDesc.GetElementSpace());
    DeviceMem workspace_device_buf(cgemmPtr->GetWorkspaceSize(
        params.M, params.N, params.K, params.StrideA, params.StrideB, params.StrideC));

    a_m_k_real_device_buf.ToDevice(A_real.mData.data());
    a_m_k_imag_device_buf.ToDevice(A_imag.mData.data());
    b_k_n_real_device_buf.ToDevice(B_real.mData.data());
    b_k_n_imag_device_buf.ToDevice(B_imag.mData.data());

    auto invoker_ptr  = cgemmPtr->MakeInvokerPointer();
    auto argument_ptr = cgemmPtr->MakeArgumentPointer(
        static_cast<ADataType*>(a_m_k_real_device_buf.GetDeviceBuffer()),
        static_cast<ADataType*>(a_m_k_imag_device_buf.GetDeviceBuffer()),
        static_cast<BDataType*>(b_k_n_real_device_buf.GetDeviceBuffer()),
        static_cast<BDataType*>(b_k_n_imag_device_buf.GetDeviceBuffer()),
        static_cast<CDataType*>(c_m_n_real_device_buf.GetDeviceBuffer()),
        static_cast<CDataType*>(c_m_n_imag_device_buf.GetDeviceBuffer()),
        static_cast<ADataType*>(workspace_device_buf.GetDeviceBuffer()),
        params.M,
        params.N,
        params.K,
        params.StrideA,
        params.StrideB,
        params.StrideC,
        a_element_op,
        b_element_op,
        c_element_op);

    if(!cgemmPtr->IsSupportedArgument(argument_ptr.get()))
    {
        throw std::runtime_error(
            "wrong! device_cgemm with the specified compilation parameters does "
            "not support this CGEMM problem");
    }
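
    // Launch the kernel; results stay in device memory until the FromDevice()
    // copies below transfer them back into the host tensors.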
    invoker_ptr->Run(argument_ptr.get());

    c_m_n_real_device_buf.FromDevice(C_real.mData.data());
    c_m_n_imag_device_buf.FromDevice(C_imag.mData.data());
}

// Functor that runs one device CGEMM instance against the host reference and
// reports SUCCESS/FAILURE.
template <typename DeviceCGemmPtr_,
          typename ADataType,
          typename BDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename CLayout,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
struct TestCGemm
{
    auto PrepareCGemmTensor(const ck::cgemm_util::CGemmParams& params)
    {
        auto f_host_tensor_descriptor =
            [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
                if(std::is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
                {
                    return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                                std::vector<std::size_t>({stride, 1}));
                }
                else
                {
                    return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                                std::vector<std::size_t>({1, stride}));
                }
            };

        Tensor<ADataType> a_m_k_real(
            f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
        Tensor<ADataType> a_m_k_imag(
            f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
        Tensor<BDataType> b_k_n_real(
            f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
        Tensor<BDataType> b_k_n_imag(
            f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
        Tensor<CDataType> c_m_n_real_host_result(
            f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
        Tensor<CDataType> c_m_n_imag_host_result(
            f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
        Tensor<CDataType> c_m_n_real_device_result(
            f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
        Tensor<CDataType> c_m_n_imag_device_result(
            f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));

        auto f_generate_tensor_value = [](auto& tensor, auto type) {
            using dataType = decltype(type);
            tensor.GenerateTensorValue(GeneratorTensor_2<dataType>{-5, 5});
        };

        f_generate_tensor_value(a_m_k_real, ADataType{});
        f_generate_tensor_value(a_m_k_imag, ADataType{});
        f_generate_tensor_value(b_k_n_real, BDataType{});
        f_generate_tensor_value(b_k_n_imag, BDataType{});

        return std::make_tuple(a_m_k_real,
                               a_m_k_imag,
                               b_k_n_real,
                               b_k_n_imag,
                               c_m_n_real_host_result,
                               c_m_n_imag_host_result,
                               c_m_n_real_device_result,
                               c_m_n_imag_device_result);
    }

    auto operator()(DeviceCGemmPtr_& cgemmPtr)
    {
        std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name
                  << ", CLayout = " << CLayout{}.name << std::endl;
        std::cout << cgemmPtr->GetTypeString() << std::endl;

        // Arrange
        ck::cgemm_util::CGemmParams params;
        params.M       = 1024;
        params.N       = 1024;
        params.K       = 1024;
        params.StrideA = 1024;
        params.StrideB = 1024;
        params.StrideC = 1024;

        auto host_tensors = PrepareCGemmTensor(params);

        const Tensor<ADataType>& a_real = std::get<0>(host_tensors);
        const Tensor<ADataType>& a_imag = std::get<1>(host_tensors);
        const Tensor<BDataType>& b_real = std::get<2>(host_tensors);
        const Tensor<BDataType>& b_imag = std::get<3>(host_tensors);

        Tensor<CDataType>& c_host_real   = std::get<4>(host_tensors);
        Tensor<CDataType>& c_host_imag   = std::get<5>(host_tensors);
        Tensor<CDataType>& c_device_real = std::get<6>(host_tensors);
        Tensor<CDataType>& c_device_imag = std::get<7>(host_tensors);

        auto a_element_op = AElementwiseOperation{};
        auto b_element_op = BElementwiseOperation{};
        auto c_element_op = CElementwiseOperation{};

        using ReferenceGemmInstance =
            ck::tensor_operation::host::ReferenceCGemm<ADataType,
                                                       BDataType,
                                                       CDataType,
                                                       AElementwiseOperation,
                                                       BElementwiseOperation,
                                                       CElementwiseOperation>;

        ck::cgemm_util::RunHostCGEMM<ReferenceGemmInstance>(a_real,
                                                            a_imag,
                                                            b_real,
                                                            b_imag,
                                                            c_host_real,
                                                            c_host_imag,
                                                            a_element_op,
                                                            b_element_op,
                                                            c_element_op);

        // Act
        ck::cgemm_util::RunDeviceCGEMM(cgemmPtr,
                                       params,
                                       a_real,
                                       a_imag,
                                       b_real,
                                       b_imag,
                                       c_device_real,
                                       c_device_imag,
                                       a_element_op,
                                       b_element_op,
                                       c_element_op);

        // Assert: compare the device results against the host reference,
        // checking the real and imaginary parts separately. The branches below
        // differ only in the CDataType they match.
        bool res = false;
        if(std::is_same<CDataType, float>::value)
        {
            const bool res_real = ck::utils::check_err(
                c_device_real.mData, c_host_real.mData, "Error: incorrect results in real part!");
            const bool res_imag =
                ck::utils::check_err(c_device_imag.mData,
                                     c_host_imag.mData,
                                     "Error: incorrect results in imaginary part!");
            res = res_real && res_imag;
            std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
        }
        else if(std::is_same<CDataType, ck::half_t>::value)
        {
            const bool res_real = ck::utils::check_err(
                c_device_real.mData, c_host_real.mData, "Error: incorrect results in real part!");
            const bool res_imag =
                ck::utils::check_err(c_device_imag.mData,
                                     c_host_imag.mData,
                                     "Error: incorrect results in imaginary part!");
            res = res_real && res_imag;
            std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
        }
        else if(std::is_same<CDataType, int8_t>::value)
        {
            const bool res_real = ck::utils::check_err(
                c_device_real.mData, c_host_real.mData, "Error: incorrect results in real part!");
            const bool res_imag =
                ck::utils::check_err(c_device_imag.mData,
                                     c_host_imag.mData,
                                     "Error: incorrect results in imaginary part!");
            res = res_real && res_imag;
            std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
        }

        return res;
    }
};

// bf16 variant: the device kernel consumes bf16 inputs, while a fp32 host
// reference provides the golden results.
template <typename DeviceCGemmPtr_,
          typename ALayout,
          typename BLayout,
          typename CLayout,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
struct TestCGemmBF16
{
    using BF16 = ck::bhalf_t;

    auto PrepareCGemmTensorBF16(const ck::cgemm_util::CGemmParams& params)
    {
        auto f_host_tensor_descriptor =
            [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
                if(std::is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
                {
                    return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                                std::vector<std::size_t>({stride, 1}));
                }
                else
                {
                    return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                                std::vector<std::size_t>({1, stride}));
                }
            };

        // use fp32 host kernel to verify bf16 device kernel
        Tensor<BF16> a_m_k_real_bf16(
            f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
        Tensor<BF16> a_m_k_imag_bf16(
            f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
        Tensor<BF16> b_k_n_real_bf16(
            f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
        Tensor<BF16> b_k_n_imag_bf16(
            f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
        Tensor<BF16> c_m_n_real_device_bf16(
            f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
        Tensor<BF16> c_m_n_imag_device_bf16(
            f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));

        Tensor<float> a_m_k_real_fp32(
            f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
        Tensor<float> a_m_k_imag_fp32(
            f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
        Tensor<float> b_k_n_real_fp32(
            f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
        Tensor<float> b_k_n_imag_fp32(
            f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
        Tensor<float> c_m_n_real_host_fp32(
            f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
        Tensor<float> c_m_n_imag_host_fp32(
            f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
        Tensor<float> c_m_n_real_device_fp32(
            f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
        Tensor<float> c_m_n_imag_device_fp32(
            f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));

        a_m_k_real_bf16.GenerateTensorValue(GeneratorTensor_3<BF16>{-0.5, 0.5});
        a_m_k_imag_bf16.GenerateTensorValue(GeneratorTensor_3<BF16>{-0.5, 0.5});
        b_k_n_real_bf16.GenerateTensorValue(GeneratorTensor_3<BF16>{-0.5, 0.5});
        b_k_n_imag_bf16.GenerateTensorValue(GeneratorTensor_3<BF16>{-0.5, 0.5});

        bf16_to_f32_(a_m_k_real_bf16, a_m_k_real_fp32);
        bf16_to_f32_(a_m_k_imag_bf16, a_m_k_imag_fp32);
        bf16_to_f32_(b_k_n_real_bf16, b_k_n_real_fp32);
        bf16_to_f32_(b_k_n_imag_bf16, b_k_n_imag_fp32);

        // Tuple layout: [0..3] bf16 A/B inputs, [4..5] bf16 device outputs,
        // [6..9] fp32 copies of the inputs for the host reference,
        // [10..11] fp32 host results, [12..13] fp32 copies of the device results.
        return std::make_tuple(a_m_k_real_bf16,
                               a_m_k_imag_bf16,
                               b_k_n_real_bf16,
                               b_k_n_imag_bf16,
                               c_m_n_real_device_bf16,
                               c_m_n_imag_device_bf16,
                               a_m_k_real_fp32,
                               a_m_k_imag_fp32,
                               b_k_n_real_fp32,
                               b_k_n_imag_fp32,
                               c_m_n_real_host_fp32,
                               c_m_n_imag_host_fp32,
                               c_m_n_real_device_fp32,
                               c_m_n_imag_device_fp32);
    }

    auto operator()(DeviceCGemmPtr_& cgemmPtr)
    {
        // Arrange
        ck::cgemm_util::CGemmParams params;
        params.M       = 1024;
        params.N       = 1024;
        params.K       = 1024;
        params.StrideA = 1024;
        params.StrideB = 1024;
        params.StrideC = 1024;

        auto host_tensors = PrepareCGemmTensorBF16(params);

        const Tensor<BF16>& a_real_bf16 = std::get<0>(host_tensors);
        const Tensor<BF16>& a_imag_bf16 = std::get<1>(host_tensors);
        const Tensor<BF16>& b_real_bf16 = std::get<2>(host_tensors);
        const Tensor<BF16>& b_imag_bf16 = std::get<3>(host_tensors);

        Tensor<BF16>& c_real_device_bf16 = std::get<4>(host_tensors);
        Tensor<BF16>& c_imag_device_bf16 = std::get<5>(host_tensors);

        Tensor<float>& a_real_fp32 = std::get<6>(host_tensors);
        Tensor<float>& a_imag_fp32 = std::get<7>(host_tensors);
        Tensor<float>& b_real_fp32 = std::get<8>(host_tensors);
        Tensor<float>& b_imag_fp32 = std::get<9>(host_tensors);

        Tensor<float>& c_real_host_fp32   = std::get<10>(host_tensors);
        Tensor<float>& c_imag_host_fp32   = std::get<11>(host_tensors);
        Tensor<float>& c_real_device_fp32 = std::get<12>(host_tensors);
        Tensor<float>& c_imag_device_fp32 = std::get<13>(host_tensors);

        auto a_element_op = AElementwiseOperation{};
        auto b_element_op = BElementwiseOperation{};
        auto c_element_op = CElementwiseOperation{};

        // use fp32 host kernel to verify bf16 device kernel
        using ReferenceCGemmInstance =
            ck::tensor_operation::host::ReferenceCGemm<float,
                                                       float,
                                                       float,
                                                       AElementwiseOperation,
                                                       BElementwiseOperation,
                                                       CElementwiseOperation>;

        ck::cgemm_util::RunHostCGEMM<ReferenceCGemmInstance>(a_real_fp32,
                                                             a_imag_fp32,
                                                             b_real_fp32,
                                                             b_imag_fp32,
                                                             c_real_host_fp32,
                                                             c_imag_host_fp32,
                                                             a_element_op,
                                                             b_element_op,
                                                             c_element_op);

        // Act
        ck::cgemm_util::RunDeviceCGEMM(cgemmPtr,
                                       params,
                                       a_real_bf16,
                                       a_imag_bf16,
                                       b_real_bf16,
                                       b_imag_bf16,
                                       c_real_device_bf16,
                                       c_imag_device_bf16,
                                       a_element_op,
                                       b_element_op,
                                       c_element_op);

        // convert the bf16 device output to fp32 before comparing
        bf16_to_f32_(c_real_device_bf16, c_real_device_fp32);
        bf16_to_f32_(c_imag_device_bf16, c_imag_device_fp32);

        // Assert (looser tolerances to account for bf16 precision)
        const bool res_real = ck::utils::check_err(c_real_device_fp32.mData,
                                                   c_real_host_fp32.mData,
                                                   "Error: incorrect results in real part!",
                                                   1e-2f,
                                                   1e-1f);
        const bool res_imag = ck::utils::check_err(c_imag_device_fp32.mData,
                                                   c_imag_host_fp32.mData,
                                                   "Error: incorrect results in imaginary part!",
                                                   1e-2f,
                                                   1e-1f);
        const bool res = res_real && res_imag;
        std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;

        return res;
    }
};

} // namespace cgemm_util
} // namespace ck

#endif // CGEMM_UTILS_HPP
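
// Illustrative usage (a sketch, not part of the original utility): a test
// translation unit would instantiate TestCGemm with a concrete device CGEMM
// pointer type, data types, layouts, and element-wise operations, then invoke
// it once per available device instance. "DeviceCGemmNoOpPtr" and
// "get_device_cgemm_instances" are hypothetical names used only for
// illustration; substitute the aliases and instance factory of the actual
// test target.
//
//   using Row         = ck::tensor_layout::gemm::RowMajor;
//   using PassThrough = ck::tensor_operation::element_wise::PassThrough;
//
//   std::vector<DeviceCGemmNoOpPtr> cgemm_ptrs;
//   get_device_cgemm_instances(cgemm_ptrs); // hypothetical instance factory
//
//   bool all_pass = true;
//   for(auto& cgemm_ptr : cgemm_ptrs)
//   {
//       all_pass &= ck::cgemm_util::TestCGemm<DeviceCGemmNoOpPtr,
//                                             float,       // ADataType
//                                             float,       // BDataType
//                                             float,       // CDataType
//                                             Row,         // ALayout
//                                             Row,         // BLayout
//                                             Row,         // CLayout
//                                             PassThrough,
//                                             PassThrough,
//                                             PassThrough>{}(cgemm_ptr);
//   }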