CMakeLists.txt 1.27 KB
Newer Older
wangkx1's avatar
init  
wangkx1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
cmake_minimum_required(VERSION 3.18)
project(test_torch_library_expand)

# C++ language standard.
# Recent PyTorch releases (2.1 and later) require C++17 to compile against their
# headers; older releases also build fine with C++17, so 17 is the safe default.
# A standard passed explicitly on the command line (-DCMAKE_CXX_STANDARD=...) is
# respected.
if(NOT DEFINED CMAKE_CXX_STANDARD)
  set(CMAKE_CXX_STANDARD 17)
endif()
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# Locate Python: the interpreter (used below to query the torch installation)
# plus the development component (provides the Python::Python imported target).
find_package(Python 3.10 COMPONENTS Interpreter Development REQUIRED)

# Locate Torch's CMake package.
# A Torch_DIR supplied by the user is used as-is, and user CMAKE_PREFIX_PATH
# entries are searched first. Otherwise, ask the Python interpreter found above
# where its torch installation keeps its CMake config files. This replaces a
# hardcoded site-packages path that only exists on one machine layout and could
# pair Torch with a different interpreter than the one CMake selected.
if(NOT Torch_DIR)
  execute_process(
    COMMAND "${Python_EXECUTABLE}" -c "import torch; print(torch.utils.cmake_prefix_path)"
    RESULT_VARIABLE torch_query_result
    OUTPUT_VARIABLE torch_cmake_prefix_path
    OUTPUT_STRIP_TRAILING_WHITESPACE
    ERROR_QUIET
  )
  if(torch_query_result EQUAL 0 AND NOT "${torch_cmake_prefix_path}" STREQUAL "")
    list(APPEND CMAKE_PREFIX_PATH "${torch_cmake_prefix_path}")
  endif()
endif()

# Find and load the Torch package (used below via TORCH_LIBRARIES and
# TORCH_INCLUDE_DIRS). Fails the configure step if it cannot be found.
find_package(Torch REQUIRED)

# ---------------------------------------------------------------------------
# test_ops: the extension library, compiled from the two C++ sources below.
# ---------------------------------------------------------------------------
add_library(test_ops SHARED)
target_sources(test_ops PRIVATE
  test_torch_library_expand.cpp
  test_ops_impl.cpp
)

# Dependencies: every entry of TORCH_LIBRARIES plus the Python imported target,
# all private to test_ops.
# NOTE(review): if this library is only ever loaded into a running Python
# process (e.g. via torch.ops.load_library), Python::Module is usually preferred
# over Python::Python so libpython is not linked in. Confirm before switching.
target_link_libraries(test_ops PRIVATE ${TORCH_LIBRARIES} Python::Python)

# Output file name: no "lib" prefix and an explicit ".so" extension,
# producing "test_ops.so".
set_target_properties(test_ops PROPERTIES PREFIX "" SUFFIX ".so")

# Header search paths for Torch and Python, private to this target.
target_include_directories(test_ops PRIVATE ${TORCH_INCLUDE_DIRS} ${Python_INCLUDE_DIRS})

# CUDA architecture list. Sources, in order of precedence:
#   1. a TORCH_CUDA_ARCH_LIST CMake variable (e.g. -DTORCH_CUDA_ARCH_LIST=...);
#   2. the TORCH_CUDA_ARCH_LIST environment variable, which is how PyTorch's
#      extension tooling is normally configured;
#   3. the built-in default below.
# Entries may be separated by spaces or semicolons (e.g. "7.5 8.0 8.6+PTX");
# they are normalized into a CMake list and empty entries are dropped.
# NOTE(review): CUDA_ARCH_LIST is only printed in the summary below. Nothing in
# this file compiles CUDA code or applies it to a target.
set(cuda_arch_spec "")
if(TORCH_CUDA_ARCH_LIST)
  set(cuda_arch_spec "${TORCH_CUDA_ARCH_LIST}")
elseif(DEFINED ENV{TORCH_CUDA_ARCH_LIST})
  set(cuda_arch_spec "$ENV{TORCH_CUDA_ARCH_LIST}")
endif()
string(REGEX MATCHALL "[^ \t;]+" CUDA_ARCH_LIST "${cuda_arch_spec}")
if(NOT CUDA_ARCH_LIST)
  set(CUDA_ARCH_LIST "6.0;6.1;7.0;7.5;8.0;8.6")
endif()

# Configuration summary.
# NOTE(review): TORCH_CUDA_AVAILABLE and CUDA_VERSION_STRING are never set in
# this file. Presumably they are expected from the Torch package or a -D option;
# confirm. If unset, the "CUDA" line prints an empty value and the CUDA details
# below are skipped.
message(STATUS "PyTorch 版本: ${Torch_VERSION}")
message(STATUS "CUDA 可用: ${TORCH_CUDA_AVAILABLE}")
if(TORCH_CUDA_AVAILABLE)
  message(STATUS "CUDA 版本: ${CUDA_VERSION_STRING}")
  message(STATUS "CUDA 架构: ${CUDA_ARCH_LIST}")
endif()