cmake_minimum_required(VERSION 3.18)
project(test_torch_library_expand)

# C++ standard. PyTorch >= 2.1 headers require C++17; C++17 also compiles the
# headers of older (C++14-era) PyTorch releases, so 17 is the safe floor.
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED TRUE)

# Locate Python. Only the extension-module part of the development component
# is requested: this library is built to be loaded into an already-running
# interpreter, so it should not link libpython itself. Python::Module applies
# the correct per-platform behavior (CMake >= 3.18).
find_package(Python 3.10 COMPONENTS Interpreter Development.Module REQUIRED)

# Locate the Torch CMake package that belongs to the Python interpreter found
# above, instead of a hard-coded site-packages path. A user-supplied
# -DTorch_DIR=... (or the cache entry left by a previous successful configure)
# takes precedence; Torch_DIR-NOTFOUND evaluates false, so a failed earlier
# attempt is retried.
if(NOT Torch_DIR)
  execute_process(
    COMMAND "${Python_EXECUTABLE}" -c
            "import torch.utils; print(torch.utils.cmake_prefix_path)"
    RESULT_VARIABLE torch_query_result
    OUTPUT_VARIABLE torch_cmake_prefix_path
    OUTPUT_STRIP_TRAILING_WHITESPACE
    ERROR_QUIET
  )
  if(torch_query_result EQUAL 0)
    # <prefix>/Torch/TorchConfig.cmake is found via CMAKE_PREFIX_PATH.
    list(APPEND CMAKE_PREFIX_PATH "${torch_cmake_prefix_path}")
  else()
    message(WARNING
      "Could not import torch with ${Python_EXECUTABLE}; "
      "pass -DTorch_DIR=<site-packages>/torch/share/cmake/Torch explicitly.")
  endif()
endif()

# Find and load the Torch package.
find_package(Torch REQUIRED)

# Extension library built from multiple source files.
add_library(test_ops SHARED
  test_torch_library_expand.cpp
  test_ops_impl.cpp
)

# Link against PyTorch and the Python extension-module interface.
target_link_libraries(test_ops PRIVATE
  ${TORCH_LIBRARIES}
  Python::Module
)

# Produce "test_ops.so" with no "lib" prefix on every platform.
# NOTE(review): presumably loaded by file name from Python
# (e.g. torch.ops.load_library) — confirm.
set_target_properties(test_ops PROPERTIES
  PREFIX ""
  SUFFIX ".so"
)

# Header search paths (kept explicit alongside the linked targets above).
target_include_directories(test_ops PRIVATE
  ${TORCH_INCLUDE_DIRS}
  ${Python_INCLUDE_DIRS}
)

# CUDA architecture list: taken from TORCH_CUDA_ARCH_LIST when set, otherwise a
# fixed default. This value is only reported below; it is not applied to any
# target in this file.
if(TORCH_CUDA_ARCH_LIST)
  set(CUDA_ARCH_LIST ${TORCH_CUDA_ARCH_LIST})
else()
  set(CUDA_ARCH_LIST "6.0;6.1;7.0;7.5;8.0;8.6")
endif()

# Configuration summary.
# NOTE(review): confirm TorchConfig.cmake actually defines TORCH_CUDA_AVAILABLE
# and CUDA_VERSION_STRING for the installed PyTorch; otherwise these print empty.
message(STATUS "PyTorch 版本: ${Torch_VERSION}")
message(STATUS "CUDA 可用: ${TORCH_CUDA_AVAILABLE}")
if(TORCH_CUDA_AVAILABLE)
  message(STATUS "CUDA 版本: ${CUDA_VERSION_STRING}")
  message(STATUS "CUDA 架构: ${CUDA_ARCH_LIST}")
endif()