#include #include #include #include #include #include #include #include "gemm_util.hpp" #include "config.hpp" #include "print.hpp" #include "device.hpp" #include "host_tensor.hpp" #include "host_tensor_generator.hpp" #include "host_gemm.hpp" #include "device_tensor.hpp" #include "device_gemm_xdl.hpp" #include "device_gemm_xdl_c_shuffle.hpp" #include "element_wise_operation.hpp" #include "reference_gemm.hpp" #include "gemm_specialization.hpp" #include "test_util.hpp" using PassThrough = ck::tensor_operation::element_wise::PassThrough; using DeviceGemmPtr_ = ck::tensor_operation::device::DeviceGemmPtr; namespace ck { namespace tensor_operation { namespace device { namespace device_gemm_instance { void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances(std::vector&); } } // namespace device } // namespace tensor_operation } // namespace ck namespace { using ADataType = int8_t; using BDataType = int8_t; using CDataType = int8_t; using AccDataType = int32_t; using ALayout = ck::tensor_layout::gemm::RowMajor; using BLayout = ck::tensor_layout::gemm::ColumnMajor; using CLayout = ck::tensor_layout::gemm::RowMajor; auto PrepareGemmTensor(const ck::gemm_util::GemmParams& params) { auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { if(std::is_same::value) { return HostTensorDescriptor(std::vector({row, col}), std::vector({stride, 1})); } else { return HostTensorDescriptor(std::vector({row, col}), std::vector({1, stride})); } }; Tensor a_m_k( f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{})); Tensor b_k_n( f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{})); Tensor c_m_n_host_result( f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); Tensor c_m_n_device_result( f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); return std::make_tuple(a_m_k, b_k_n, c_m_n_host_result, c_m_n_device_result); } bool TestGemm(DeviceGemmPtr_& gemmPtr) { // Arrange ck::gemm_util::GemmParams params; params.M = 1024; params.N = 1024; params.K = 1024; params.StrideA = 1024; params.StrideB = 1024; params.StrideC = 1024; auto host_tensors = PrepareGemmTensor(params); const Tensor& a = std::get<0>(host_tensors); const Tensor& b = std::get<1>(host_tensors); Tensor& c_host = std::get<2>(host_tensors); Tensor& c_device = std::get<3>(host_tensors); auto a_element_op = PassThrough{}; auto b_element_op = PassThrough{}; auto c_element_op = PassThrough{}; using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; ck::gemm_util::RunHostGEMM( a, b, c_host, a_element_op, b_element_op, c_element_op); // Act ck::gemm_util::RunDeviceGEMM( gemmPtr, params, a, b, c_device, a_element_op, b_element_op, c_element_op); // Assert bool res = test_util::check_err(c_device.mData, c_host.mData, "Error: incorrect results!"); std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; return res; } } // anonymous namespace int main() { std::vector gemmPtrs; ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances(gemmPtrs); bool res = true; for(auto& gemmPtr : gemmPtrs) { res &= TestGemm(gemmPtr); } std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; }