// SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include #include "ck/ck.hpp" #include "ck/stream_config.hpp" #include "ck/host_utility/hip_check_error.hpp" template float launch_and_time_kernel(const StreamConfig& stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args) { #if CK_TIME_KERNEL if(stream_config.time_kernel_) { #if 0 printf("HipGraph OFF\n"); if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { printf("%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n", __func__, grid_dim.x, grid_dim.y, grid_dim.z, block_dim.x, block_dim.y, block_dim.z); printf("Warm up %d times\n", stream_config.cold_niters_); } // warm up for(int i = 0; i < stream_config.cold_niters_; ++i) { kernel<<>>(args...); hip_check_error(hipGetLastError()); } const int nrepeat = stream_config.nrepeat_; if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { printf("Start running %d times...\n", nrepeat); } hipEvent_t start, stop; hip_check_error(hipEventCreate(&start)); hip_check_error(hipEventCreate(&stop)); hip_check_error(hipDeviceSynchronize()); hip_check_error(hipEventRecord(start, stream_config.stream_id_)); for(int i = 0; i < nrepeat; ++i) { kernel<<>>(args...); hip_check_error(hipGetLastError()); } hip_check_error(hipEventRecord(stop, stream_config.stream_id_)); hip_check_error(hipEventSynchronize(stop)); float total_time = 0; hip_check_error(hipEventElapsedTime(&total_time, start, stop)); hip_check_error(hipEventDestroy(start)); hip_check_error(hipEventDestroy(stop)); return total_time / nrepeat; #elif 1 printf("HipGraph ON\n"); hipGraph_t graph_; hipStream_t stream_; HIP_CHECK_ERROR(hipStreamCreate(&stream_)); StreamConfig sc{stream_}; HIP_CHECK_ERROR(hipStreamBeginCapture(sc.stream_id_, hipStreamCaptureModeGlobal)); for(int i_r = 0; i_r < stream_config.nrepeat_; i_r++) { kernel<<>>(args...); } HIP_CHECK_ERROR(hipStreamEndCapture(sc.stream_id_, &graph_)); hipGraphExec_t instance_; HIP_CHECK_ERROR(hipGraphInstantiate(&instance_, graph_, nullptr, nullptr, 0)); hipEvent_t start_, stop_; HIP_CHECK_ERROR(hipEventCreate(&start_)); HIP_CHECK_ERROR(hipEventCreate(&stop_)); // warm-up for(int i_r = 0; i_r < stream_config.cold_niters_; i_r++) { kernel<<>>(args...); } HIP_CHECK_ERROR(hipDeviceSynchronize()); HIP_CHECK_ERROR(hipEventRecord(start_, sc.stream_id_)); HIP_CHECK_ERROR(hipGraphLaunch(instance_, sc.stream_id_)); HIP_CHECK_ERROR(hipEventRecord(stop_, sc.stream_id_)); HIP_CHECK_ERROR(hipEventSynchronize(stop_)); HIP_CHECK_ERROR(hipGetLastError()); HIP_CHECK_ERROR(hipGraphDestroy(graph_)); float total_time = 0; HIP_CHECK_ERROR(hipEventElapsedTime(&total_time, start_, stop_)); return total_time / stream_config.nrepeat_; #endif } else { kernel<<>>(args...); hip_check_error(hipGetLastError()); return 0; } #else kernel<<>>(args...); hip_check_error(hipGetLastError()); return 0; #endif } template float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args) { #if CK_TIME_KERNEL if(stream_config.time_kernel_) { if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { printf("%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n", __func__, grid_dim.x, grid_dim.y, grid_dim.z, block_dim.x, block_dim.y, block_dim.z); printf("Warm up %d times\n", stream_config.cold_niters_); } // warm up preprocess(); for(int i = 0; i < stream_config.cold_niters_; ++i) { kernel<<>>(args...); hip_check_error(hipGetLastError()); } const int nrepeat = stream_config.nrepeat_; if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { printf("Start running %d times...\n", nrepeat); } hipEvent_t start, stop; hip_check_error(hipEventCreate(&start)); hip_check_error(hipEventCreate(&stop)); hip_check_error(hipDeviceSynchronize()); hip_check_error(hipEventRecord(start, stream_config.stream_id_)); for(int i = 0; i < nrepeat; ++i) { preprocess(); kernel<<>>(args...); hip_check_error(hipGetLastError()); } hip_check_error(hipEventRecord(stop, stream_config.stream_id_)); hip_check_error(hipEventSynchronize(stop)); float total_time = 0; hip_check_error(hipEventElapsedTime(&total_time, start, stop)); hip_check_error(hipEventDestroy(start)); hip_check_error(hipEventDestroy(stop)); return total_time / nrepeat; } else { preprocess(); kernel<<>>(args...); hip_check_error(hipGetLastError()); return 0; } #else kernel<<>>(args...); hip_check_error(hipGetLastError()); return 0; #endif }