cmake_minimum_required(VERSION 3.21 FATAL_ERROR)
project(FastPTCOverheadMRE LANGUAGES CXX)

# Project-wide defaults picked up by every target created below:
# C++17 is mandatory (no silent fallback to an older standard), and all
# objects are built position-independent.
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)

# GPU backend to build against. Override with -DBACKEND=cuda; any value other
# than "hip" or "cuda" is rejected later in this file.
set(BACKEND "hip" CACHE STRING "Backend: hip or cuda")
# Advertise the accepted values so cmake-gui / ccmake present a drop-down
# instead of a free-text field.
set_property(CACHE BACKEND PROPERTY STRINGS hip cuda)

# Ask the python3 found on PATH where its torch package keeps its CMake
# package files, so find_package(Torch) can locate them automatically.
# This is best-effort: a failure here is only a warning, because the caller
# may already have supplied Torch through CMAKE_PREFIX_PATH or Torch_DIR.
execute_process(
    COMMAND python3 -c "import torch; print(torch.utils.cmake_prefix_path)"
    RESULT_VARIABLE TORCH_CMAKE_PREFIX_QUERY_RESULT
    OUTPUT_VARIABLE TORCH_CMAKE_PREFIX_PATH
    OUTPUT_STRIP_TRAILING_WHITESPACE
)
# RESULT_VARIABLE holds the exit code on a normal run, or an error string when
# the process could not be started at all; compare as a string to cover both.
if("${TORCH_CMAKE_PREFIX_QUERY_RESULT}" STREQUAL "0"
   AND NOT "${TORCH_CMAKE_PREFIX_PATH}" STREQUAL "")
    list(APPEND CMAKE_PREFIX_PATH "${TORCH_CMAKE_PREFIX_PATH}")
else()
    message(WARNING
        "Could not query torch.utils.cmake_prefix_path via python3 "
        "(result: '${TORCH_CMAKE_PREFIX_QUERY_RESULT}'). "
        "Torch will only be found if CMAKE_PREFIX_PATH or Torch_DIR points at it.")
endif()

# Torch is mandatory; configuration stops here if it cannot be located.
find_package(Torch REQUIRED)

# Remove -Wno-duplicate-decl-specifier from Torch's suggested flags before
# adopting them (NOTE(review): presumably one of the compilers used here does
# not accept that option - confirm), then add the remainder to the global
# C++ flags.
string(REPLACE "-Wno-duplicate-decl-specifier" "" TORCH_CXX_FLAGS "${TORCH_CXX_FLAGS}")
string(APPEND CMAKE_CXX_FLAGS " ${TORCH_CXX_FLAGS}")

# disable_noisy_warnings(<target_name>)
#
# Adds -Wno-unused-result as a PRIVATE compile option of <target_name>.
# C++ and HIP sources receive the flag directly; CUDA sources receive it
# wrapped in -Xcompiler= (presumably so nvcc hands it to the host compiler).
#
# Example: disable_noisy_warnings(device_query)
function(disable_noisy_warnings target_name)
    set(quiet_flag "-Wno-unused-result")
    target_compile_options(${target_name} PRIVATE
        "$<$<COMPILE_LANGUAGE:CXX,HIP>:${quiet_flag}>"
        "$<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=${quiet_flag}>"
    )
endfunction()

# Enable the GPU language selected by BACKEND and record the two things that
# differ between backends: the language used to compile device_query.cpp and
# the BACKEND_* preprocessor definition handed to both targets.
if(BACKEND STREQUAL "hip")
    enable_language(HIP)
    set(mre_gpu_language HIP)
    set(mre_backend_definition BACKEND_HIP=1)
elseif(BACKEND STREQUAL "cuda")
    # The default architecture list has to be chosen BEFORE enable_language(CUDA):
    # enabling the language initializes CMAKE_CUDA_ARCHITECTURES itself, so a
    # NOT DEFINED check placed afterwards can never fire. A value given with
    # -DCMAKE_CUDA_ARCHITECTURES=... or the CUDAARCHS environment variable
    # still takes precedence.
    # NOTE(review): if find_package(Torch) has already enabled CUDA earlier in
    # this file, the variable is defined by then and this default is skipped.
    if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES AND NOT DEFINED ENV{CUDAARCHS})
        set(CMAKE_CUDA_ARCHITECTURES "60;70;80;90")
    endif()
    enable_language(CUDA)
    # Pass system include directories to the CUDA compiler with plain -I.
    set(CMAKE_INCLUDE_SYSTEM_FLAG_CUDA "-I")
    set(CMAKE_CUDA_STANDARD 17)
    set(CMAKE_CUDA_STANDARD_REQUIRED ON)
    set(mre_gpu_language CUDA)
    set(mre_backend_definition BACKEND_CUDA=1)
else()
    message(FATAL_ERROR "BACKEND must be hip or cuda")
endif()

# device_query.cpp keeps its .cpp extension but is compiled as GPU code, so
# its language is overridden explicitly. guard_ext.cpp stays ordinary C++.
set_source_files_properties(src/device_query.cpp
    PROPERTIES LANGUAGE ${mre_gpu_language})
add_executable(device_query src/device_query.cpp)
add_library(guard_ext SHARED src/guard_ext.cpp)

target_compile_definitions(device_query PRIVATE ${mre_backend_definition})
target_compile_definitions(guard_ext PRIVATE ${mre_backend_definition})

# Always optimize the GPU translation unit, independent of the build type.
target_compile_options(device_query PRIVATE
    $<$<COMPILE_LANGUAGE:${mre_gpu_language}>:-O3>)

# Of the two targets, only guard_ext is linked against the Torch libraries.
target_link_libraries(guard_ext PRIVATE ${TORCH_LIBRARIES})

# Apply the shared warning suppression to both targets.
foreach(mre_target IN ITEMS device_query guard_ext)
    disable_noisy_warnings(${mre_target})
endforeach()

# Place the executable under <build>/bin and the shared library under
# <build>/lib.
set_property(TARGET device_query
    PROPERTY RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin")
set_property(TARGET guard_ext
    PROPERTY LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
