cmake_minimum_required(VERSION 3.18)
project(test_torch_library_expand)

# C++ standard. PyTorch 2.x headers require C++17, and the _C target below is
# compiled as C++17 anyway, so make the project-wide default agree with it
# (previously 14, which contradicted the per-target setting).
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# Locate the Python interpreter (>= 3.10) and the headers needed to build an
# extension module. Development.Module (CMake >= 3.18) only requires the
# headers, not an embeddable libpython.
find_package(Python 3.10 COMPONENTS Interpreter Development.Module REQUIRED)

# Locate PyTorch's CMake package.
# - A user-supplied -DTorch_DIR=... (or a previous successful configure) wins.
# - Otherwise ask the selected Python interpreter where its torch installation
#   keeps its CMake config files.
# - If that query fails, fall back to the previously hard-coded location.
if(NOT Torch_DIR)
  execute_process(
    COMMAND "${Python_EXECUTABLE}" -c "import torch; print(torch.utils.cmake_prefix_path)"
    OUTPUT_VARIABLE torch_cmake_prefix
    RESULT_VARIABLE torch_query_result
    OUTPUT_STRIP_TRAILING_WHITESPACE
    ERROR_QUIET
  )
  if(torch_query_result EQUAL 0 AND torch_cmake_prefix)
    set(Torch_DIR "${torch_cmake_prefix}/Torch")
  else()
    set(Torch_DIR /usr/local/lib/python3.10/dist-packages/torch/share/cmake/Torch)
  endif()
endif()

# Find and load the Torch package (defines TORCH_LIBRARIES, TORCH_INCLUDE_DIRS).
find_package(Torch REQUIRED)

# Build the extension library _C from its two translation units.
add_library(_C SHARED
    test_torch_library_expand.cpp
    test_ops_impl.cpp
)

# Compile _C as C++17 (required by the PyTorch headers).
set_property(TARGET _C PROPERTY CXX_STANDARD 17)

# Link against PyTorch. Use the keyword signature: the plain, keyword-less form
# of target_link_libraries() has legacy semantics. PRIVATE because nothing in
# this file links against _C, so Torch's usage requirements are only needed
# while building _C itself.
target_link_libraries(_C PRIVATE ${TORCH_LIBRARIES})

# Name the artifact "_C.so" (no "lib" prefix) so it can be loaded as a module.
# NOTE(review): the output goes into the source directory (presumably so Python
# can load it from next to the sources), and the hard-coded ".so" suffix is
# specific to Linux/macOS; revisit both if this is built elsewhere.
set_target_properties(_C PROPERTIES
    PREFIX ""
    SUFFIX ".so"
    LIBRARY_OUTPUT_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}"
)

# Header search paths for PyTorch and the Python C API.
target_include_directories(_C PRIVATE
    ${TORCH_INCLUDE_DIRS}
    ${Python_INCLUDE_DIRS}
)

# Choose the CUDA architecture list, in order of precedence:
#   1. the CMake variable TORCH_CUDA_ARCH_LIST (-DTORCH_CUDA_ARCH_LIST=...),
#   2. the environment variable TORCH_CUDA_ARCH_LIST (the convention used by
#      PyTorch's extension tooling; space-separated values are converted to a
#      CMake list),
#   3. a fixed default list.
# NOTE(review): in this file CUDA_ARCH_LIST is only printed; it is not applied
# to any target (no CUDA sources or CUDA_ARCHITECTURES are configured here).
if(TORCH_CUDA_ARCH_LIST)
  set(CUDA_ARCH_LIST "${TORCH_CUDA_ARCH_LIST}")
elseif(DEFINED ENV{TORCH_CUDA_ARCH_LIST} AND NOT "$ENV{TORCH_CUDA_ARCH_LIST}" STREQUAL "")
  string(REPLACE " " ";" CUDA_ARCH_LIST "$ENV{TORCH_CUDA_ARCH_LIST}")
else()
  set(CUDA_ARCH_LIST "6.0;6.1;7.0;7.5;8.0;8.6")
endif()

# Configure-time summary: PyTorch version and CUDA availability, plus CUDA
# details when the availability flag is truthy.
# NOTE(review): TORCH_CUDA_AVAILABLE and CUDA_VERSION_STRING are not set in
# this file; confirm that the Torch package (or its dependencies) defines them,
# otherwise the CUDA lines below are never printed.
message(STATUS "PyTorch 版本: ${Torch_VERSION}")
message(STATUS "CUDA 可用: ${TORCH_CUDA_AVAILABLE}")

if(TORCH_CUDA_AVAILABLE)
    # CUDA toolkit version, then the architecture list chosen above.
    message(STATUS "CUDA 版本: ${CUDA_VERSION_STRING}")
    message(STATUS "CUDA 架构: ${CUDA_ARCH_LIST}")
endif()