// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iomanip>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"

namespace ck {
namespace profiler {

template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename DDataType,
          typename EDataType,
          typename ALayout,
          typename BLayout,
          typename DELayout> // assume Ds and E have the same layout
bool profile_gemm_bilinear_impl(int do_verification,
                                int init_method,
                                bool /*do_log*/,
                                bool time_kernel,
                                int M,
                                int N,
                                int K,
                                int StrideA,
                                int StrideB,
                                int StrideD,
                                int StrideE,
                                float alpha,
                                float beta)
{
    // build a 2D host tensor descriptor whose strides follow the requested layout
    auto f_host_tensor_descriptor =
        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
            if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
            {
                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                            std::vector<std::size_t>({stride, 1}));
            }
            else
            {
                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                            std::vector<std::size_t>({1, stride}));
            }
        };

    Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
    Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
    Tensor<DDataType> d_m_n(f_host_tensor_descriptor(M, N, StrideD, DELayout{}));
    Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, DELayout{}));
    Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, DELayout{}));

    std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
    std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
    std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
    std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl;

    switch(init_method)
    {
    case 0: break; // leave tensors uninitialized
    case 1:
        a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
        b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
        d_m_n.GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
        break;
    default:
        a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
        b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
        d_m_n.GenerateTensorValue(GeneratorTensor_3<DDataType>{0.0, 1.0});
    }

    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using Bilinear    = ck::tensor_operation::element_wise::Bilinear;

    using AElementOp   = PassThrough;
    using BElementOp   = PassThrough;
    using CDEElementOp = Bilinear;

    const auto a_element_op   = AElementOp{};
    const auto b_element_op   = BElementOp{};
    const auto cde_element_op = CDEElementOp{alpha, beta};

    using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD<
        ALayout,
        BLayout,
        DELayout,
        ADataType,
        BDataType,
        ck::Tuple<DDataType>,
        EDataType,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::Bilinear>;

    // get device op instances
    const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
        DeviceOp>::GetInstances();

    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;

    // run reference
    if(do_verification)
    {
        Tensor<AccDataType> c_m_n(HostTensorDescriptor(
            std::vector<std::size_t>{static_cast<std::size_t>(M), static_cast<std::size_t>(N)}));

        using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
                                                                                BDataType,
                                                                                AccDataType,
                                                                                AccDataType,
                                                                                AElementOp,
                                                                                BElementOp,
                                                                                PassThrough>;

        auto ref_gemm    = ReferenceGemmInstance{};
        auto ref_invoker = ref_gemm.MakeInvoker();

        auto ref_argument =
            ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{});

        ref_invoker.Run(ref_argument);
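
        // apply the bilinear epilogue element by element on the host:
        // e_m_n_host_result(m, n) = alpha * c_m_n(m, n) + beta * d_m_n(m, n)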
        for(int m = 0; m < M; ++m)
        {
            for(int n = 0; n < N; ++n)
            {
                cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n));
            }
        }
    }

    DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
    DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
    DeviceMem d_m_n_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpace());
    DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpace());

    a_device_buf.ToDevice(a_m_k.mData.data());
    b_device_buf.ToDevice(b_k_n.mData.data());
    d_m_n_device_buf.ToDevice(d_m_n.mData.data());

    std::string best_op_name;
    float best_ave_time   = 0;
    float best_tflops     = 0;
    float best_gb_per_sec = 0;

    bool pass = true;

    // profile device operation instances
    for(auto& op_ptr : op_ptrs)
    {
        auto argument_ptr = op_ptr->MakeArgumentPointer(
            a_device_buf.GetDeviceBuffer(),
            b_device_buf.GetDeviceBuffer(),
            std::array<const void*, 1>{d_m_n_device_buf.GetDeviceBuffer()},
            e_device_buf.GetDeviceBuffer(),
            M,
            N,
            K,
            StrideA,
            StrideB,
            std::array<ck::index_t, 1>{StrideD},
            StrideE,
            a_element_op,
            b_element_op,
            cde_element_op);

        auto invoker_ptr = op_ptr->MakeInvokerPointer();

        std::string op_name = op_ptr->GetTypeString();

        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            // re-init E to zero before profiling a kernel
            e_device_buf.SetZero();

            float ave_time =
                invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});

            std::size_t flop = std::size_t(2) * M * N * K;

            std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
                                    sizeof(EDataType) * M * N;

            // ave_time is in ms: FLOP / 1e9 / ms = TFLOP/s, bytes / 1e6 / ms = GB/s
            float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

            float gb_per_sec = num_btype / 1.E6 / ave_time;

            std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
                      << gb_per_sec << " GB/s, " << op_name << std::endl;

            if(tflops > best_tflops)
            {
                best_op_name    = op_name;
                best_tflops     = tflops;
                best_ave_time   = ave_time;
                best_gb_per_sec = gb_per_sec;
            }

            if(do_verification)
            {
                e_device_buf.FromDevice(e_m_n_device_result.mData.data());

                pass = pass &&
                       ck::utils::check_err(e_m_n_device_result.mData, e_m_n_host_result.mData);
            }
        }
        else
        {
            std::cout << op_name << " does not support this problem" << std::endl;
        }
    }

    std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
              << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;

    return pass;
}

} // namespace profiler
} // namespace ck
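
// Example usage (illustrative sketch, not part of the original header): a caller might
// instantiate this template for fp16 tensors with fp32 accumulation, e.g. row-major A,
// column-major B, row-major D/E. The shapes and strides below are hypothetical:
//
//   using F16 = ck::half_t;
//   using F32 = float;
//   using Row = ck::tensor_layout::gemm::RowMajor;
//   using Col = ck::tensor_layout::gemm::ColumnMajor;
//
//   bool pass = ck::profiler::profile_gemm_bilinear_impl<F16, F16, F32, F16, F16, Row, Col, Row>(
//       1,     // do_verification: check device results against the host reference
//       1,     // init_method: integer values in [-5, 5]
//       false, // do_log (unused)
//       true,  // time_kernel: time the kernel runs
//       3840, 4096, 4096,        // M, N, K
//       4096, 4096, 4096, 4096,  // StrideA, StrideB, StrideD, StrideE
//       1.0f, 1.0f);             // alpha, beta in e = alpha * (a @ b) + beta * d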