cmake_minimum_required(VERSION 3.18)
project(test_torch_library_expand)

# C++ standard. PyTorch 2.x headers refuse to compile below C++17, so the
# previous C++14 setting breaks against current releases.
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)

# Python interpreter and headers. The Development.Module component (CMake >= 3.18)
# exposes the Python include directories through the Python3::Module target
# without linking libpython. This replaces the duplicated, hard-coded
# include_directories(/usr/include/python3.10) and the unused link_directories()
# call. Pick a specific interpreter with -DPython3_EXECUTABLE=... if needed.
find_package(Python3 REQUIRED COMPONENTS Interpreter Development.Module)

# Locate Torch. A user-supplied -DTorch_DIR=... always wins. Otherwise ask the
# selected Python's torch package where its CMake config lives, and fall back
# to the previously hard-coded dist-packages path only if that query fails.
if(NOT DEFINED Torch_DIR)
  execute_process(
    COMMAND "${Python3_EXECUTABLE}" -c "import torch; print(torch.utils.cmake_prefix_path)"
    RESULT_VARIABLE torch_query_result
    OUTPUT_VARIABLE torch_cmake_prefix
    OUTPUT_STRIP_TRAILING_WHITESPACE
    ERROR_QUIET
  )
  if(torch_query_result EQUAL 0 AND EXISTS "${torch_cmake_prefix}/Torch")
    set(Torch_DIR "${torch_cmake_prefix}/Torch")
  else()
    set(Torch_DIR /usr/local/lib/python3.10/dist-packages/torch/share/cmake/Torch)
  endif()
endif()

# Find and load the Torch package.
find_package(Torch REQUIRED)

# Build the operator library as a shared object.
add_library(test_ops SHARED test_torch_library_expand.cpp)

# Link PyTorch and pick up the Python headers via Python3::Module.
# NOTE(review): if test_torch_library_expand.cpp uses pybind11 bindings
# (torch/extension.h + PYBIND11_MODULE), libtorch_python may also be needed;
# confirm against the source file.
target_link_libraries(test_ops PRIVATE ${TORCH_LIBRARIES} Python3::Module)

# PyTorch include directories (kept explicitly in addition to the linked targets).
target_include_directories(test_ops PRIVATE ${TORCH_INCLUDE_DIRS})

# Name the output "test_ops.so": no "lib" prefix, explicit .so suffix.
set_target_properties(test_ops PROPERTIES
  PREFIX ""
  SUFFIX ".so"
)

# CUDA architecture list. This is informational only: nothing in this file
# applies it to test_ops, which is built from a .cpp source.
if(TORCH_CUDA_ARCH_LIST)
  set(CUDA_ARCH_LIST ${TORCH_CUDA_ARCH_LIST})
else()
  set(CUDA_ARCH_LIST "6.0;6.1;7.0;7.5;8.0;8.6")
endif()

# Configuration summary.
# NOTE(review): confirm that TorchConfig.cmake actually defines
# TORCH_CUDA_AVAILABLE and CUDA_VERSION_STRING. If it does not, the
# availability line prints an empty value and the CUDA block is skipped.
message(STATUS "PyTorch 版本: ${Torch_VERSION}")
message(STATUS "CUDA 可用: ${TORCH_CUDA_AVAILABLE}")
if(TORCH_CUDA_AVAILABLE)
  message(STATUS "CUDA 版本: ${CUDA_VERSION_STRING}")
  message(STATUS "CUDA 架构: ${CUDA_ARCH_LIST}")
endif()