cmake_minimum_required(VERSION 3.18)
project(test_torch_library_expand)

# C++ standard used for every target in this project.
# PyTorch 2.x headers require C++17; compiling them as C++14 fails. A
# CXX_STANDARD set on an imported target is not propagated to consumers,
# so the standard has to be requested here.
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED TRUE)

# Find Python 3.10 or newer: the interpreter plus development headers/libraries.
find_package(Python 3.10 COMPONENTS Interpreter Development REQUIRED)

# Locate Torch's CMake package config.
# Instead of hard-coding one machine's site-packages path, ask the interpreter
# found above where its installed torch package keeps TorchConfig.cmake. This
# way the Torch build always matches that Python. A Torch_DIR supplied by the
# user (-DTorch_DIR=<dir>) or cached from a previous successful configure is
# left untouched.
if(NOT Torch_DIR)
  execute_process(
    COMMAND "${Python_EXECUTABLE}" -c "import torch; print(torch.utils.cmake_prefix_path)"
    RESULT_VARIABLE test_ops_torch_query_result
    OUTPUT_VARIABLE test_ops_torch_cmake_prefix
    ERROR_VARIABLE test_ops_torch_query_error
    OUTPUT_STRIP_TRAILING_WHITESPACE
    ERROR_STRIP_TRAILING_WHITESPACE
  )
  if(test_ops_torch_query_result EQUAL 0 AND EXISTS "${test_ops_torch_cmake_prefix}/Torch")
    # cmake_prefix_path points at .../torch/share/cmake; the package
    # config itself lives in its Torch/ subdirectory.
    set(Torch_DIR "${test_ops_torch_cmake_prefix}/Torch")
  else()
    # Not fatal here: find_package(Torch REQUIRED) below reports the failure
    # and the user can still pass -DTorch_DIR or CMAKE_PREFIX_PATH.
    message(STATUS "Could not locate torch via ${Python_EXECUTABLE}: ${test_ops_torch_query_error}")
  endif()
endif()

# Find and load the Torch package; the rest of this file reads
# TORCH_LIBRARIES and TORCH_INCLUDE_DIRS from it.
find_package(Torch REQUIRED)

# Build the extension library from its two translation units.
add_library(test_ops SHARED
  test_torch_library_expand.cpp
  test_ops_impl.cpp
)

# Link PyTorch and the Python extension-module target.
# NOTE(review): Python::Module is used instead of Python::Python on the
# assumption that this library is only ever loaded into a running Python
# process. Its bare "test_ops.so" name suggests torch.ops.load_library or
# import — confirm.
# Python::Module lets the Python C-API symbols resolve from the host
# interpreter instead of linking libpython into the extension. An extension
# module should not link libpython: doing so can load a second copy of the
# runtime into a statically linked interpreter.
target_link_libraries(test_ops PRIVATE
  ${TORCH_LIBRARIES}
  Python::Module
)

# Name the output "test_ops.so": no "lib" prefix, and a fixed ".so" suffix
# on every platform.
set_target_properties(test_ops PROPERTIES
  PREFIX ""
  SUFFIX ".so"
)

# Header search paths. Python::Module already provides the Python include
# path, and the Torch imported target is expected to carry Torch's. They are
# kept explicit here as a harmless safety net.
target_include_directories(test_ops PRIVATE
  ${TORCH_INCLUDE_DIRS}
  ${Python_INCLUDE_DIRS}
)

# Choose the CUDA architecture list shown in the status summary.
# Start from a built-in default; a truthy TORCH_CUDA_ARCH_LIST CMake variable
# (e.g. -DTORCH_CUDA_ARCH_LIST="8.0;8.6") replaces it. As far as this file
# shows, the result is only printed and is never applied to a target.
set(CUDA_ARCH_LIST "6.0;6.1;7.0;7.5;8.0;8.6")
if(TORCH_CUDA_ARCH_LIST)
  set(CUDA_ARCH_LIST ${TORCH_CUDA_ARCH_LIST})
endif()

# Print a configuration summary.
message(STATUS "PyTorch 版本: ${Torch_VERSION}")

# Decide whether the Torch package that was found is CUDA-enabled.
# NOTE(review): TORCH_CUDA_AVAILABLE does not appear to be defined by
# TorchConfig.cmake. When it is undefined, testing it alone never prints the
# CUDA details. The check therefore also accepts variables the Torch/Caffe2
# package configs are believed to set for CUDA builds (CAFFE2_USE_CUDA,
# TORCH_CUDA_LIBRARIES) — confirm against the installed torch version.
if(TORCH_CUDA_AVAILABLE OR CAFFE2_USE_CUDA OR TORCH_CUDA_LIBRARIES)
  set(test_ops_cuda_available ON)
else()
  set(test_ops_cuda_available OFF)
endif()
message(STATUS "CUDA 可用: ${test_ops_cuda_available}")

if(test_ops_cuda_available)
  # Which toolkit-version variable exists depends on the CUDA finder the Torch
  # config used. Legacy FindCUDA sets CUDA_VERSION_STRING / CUDA_VERSION;
  # FindCUDAToolkit sets CUDAToolkit_VERSION. Report the first one that is set.
  if(CUDA_VERSION_STRING)
    set(test_ops_cuda_version "${CUDA_VERSION_STRING}")
  elseif(CUDAToolkit_VERSION)
    set(test_ops_cuda_version "${CUDAToolkit_VERSION}")
  else()
    set(test_ops_cuda_version "${CUDA_VERSION}")
  endif()
  message(STATUS "CUDA 版本: ${test_ops_cuda_version}")
  message(STATUS "CUDA 架构: ${CUDA_ARCH_LIST}")
endif()