torch.cmake 2.24 KB
Newer Older
pkufool's avatar
pkufool committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# Copyright (c)  2021  Xiaomi Corporation (authors: Fangjun Kuang)
# Locate the PyTorch package that belongs to the active Python interpreter
# and import its CMake configuration.
#
# Inputs:
#   PYTHON_EXECUTABLE - set by pybind11.cmake
# Outputs:
#   TORCH_DIR         - directory containing the installed torch package
#   plus whatever find_package(Torch) defines.
message(STATUS "Python executable: ${PYTHON_EXECUTABLE}")

# Ask Python where the torch package lives. stderr is deliberately not
# captured, so any Python traceback is not swallowed by this call.
execute_process(
  COMMAND "${PYTHON_EXECUTABLE}" -c "import os; import torch; print(os.path.dirname(torch.__file__))"
  RESULT_VARIABLE ot_torch_dir_status
  OUTPUT_STRIP_TRAILING_WHITESPACE
  OUTPUT_VARIABLE TORCH_DIR
)

# Fail here with an actionable message instead of letting find_package()
# report a confusing "could not find Torch" error for an empty prefix path.
# RESULT_VARIABLE holds either the exit code or an error string, hence the
# string comparison.
if(NOT "${ot_torch_dir_status}" STREQUAL "0" OR "${TORCH_DIR}" STREQUAL "")
  message(FATAL_ERROR
    "Failed to locate PyTorch using the Python interpreter "
    "'${PYTHON_EXECUTABLE}' (result: ${ot_torch_dir_status}).\n"
    "Please make sure that PyTorch is installed for this interpreter.\n"
  )
endif()

# Make TorchConfig.cmake discoverable, then load it.
list(APPEND CMAKE_PREFIX_PATH "${TORCH_DIR}")
find_package(Torch REQUIRED)

# Append the flags exported by PyTorch (TORCH_CXX_FLAGS) to the global
# compiler flags so that optimized_transducer is built with the same
# ABI flag as PyTorch. The CUDA flags receive the same addition when
# CUDA support is enabled.
string(APPEND CMAKE_CXX_FLAGS " ${TORCH_CXX_FLAGS}")
if(OT_WITH_CUDA)
  string(APPEND CMAKE_CUDA_FLAGS " ${TORCH_CXX_FLAGS}")
endif()


# Determine the PyTorch version.
#
# Outputs:
#   TORCH_VERSION          - full version string reported by torch.__version__
#   OT_TORCH_VERSION_MAJOR - text before the first '.' of TORCH_VERSION
#   OT_TORCH_VERSION_MINOR - text between the first and second '.'
#
# The interpreter is started (and torch imported) only once; the major and
# minor components are derived from the full string inside CMake.
execute_process(
  COMMAND "${PYTHON_EXECUTABLE}" -c "import torch; print(torch.__version__)"
  RESULT_VARIABLE ot_torch_version_status
  OUTPUT_STRIP_TRAILING_WHITESPACE
  OUTPUT_VARIABLE TORCH_VERSION
)

# RESULT_VARIABLE holds either the exit code or an error string, hence the
# string comparison.
if(NOT "${ot_torch_version_status}" STREQUAL "0")
  message(FATAL_ERROR
    "Failed to query the PyTorch version using '${PYTHON_EXECUTABLE}' "
    "(result: ${ot_torch_version_status}).\n"
  )
endif()

# Equivalent to torch.__version__.split('.')[0] and [1]: the first two
# '.'-separated fields of the version string.
if("${TORCH_VERSION}" MATCHES "^([^.]*)\\.([^.]*)")
  set(OT_TORCH_VERSION_MAJOR "${CMAKE_MATCH_1}")
  set(OT_TORCH_VERSION_MINOR "${CMAKE_MATCH_2}")
else()
  message(FATAL_ERROR
    "Cannot extract major/minor version from PyTorch version "
    "'${TORCH_VERSION}'.\n"
  )
endif()

message(STATUS "PyTorch version: ${TORCH_VERSION}")

# CUDA-only configuration: verify that the CUDA toolkit used for this build
# matches the one PyTorch reports, then clear the compile options that the
# imported torch targets would otherwise pass on to NVCC.
if(OT_WITH_CUDA)
  # Ask PyTorch which CUDA version it reports.
  execute_process(
    COMMAND "${PYTHON_EXECUTABLE}" -c "import torch; print(torch.version.cuda)"
    OUTPUT_STRIP_TRAILING_WHITESPACE
    OUTPUT_VARIABLE TORCH_CUDA_VERSION
  )

  message(STATUS "PyTorch cuda version: ${TORCH_CUDA_VERSION}")

  # print(None) yields the literal text "None", which is what a PyTorch build
  # without CUDA support is expected to report; an empty string means the
  # query itself produced no output. Neither is a version number, so report
  # the real problem instead of a confusing version mismatch.
  if("${TORCH_CUDA_VERSION}" STREQUAL "None" OR "${TORCH_CUDA_VERSION}" STREQUAL "")
    message(FATAL_ERROR
      "OT_WITH_CUDA is enabled, but the CUDA version of PyTorch "
      "${TORCH_VERSION} could not be determined "
      "(got: '${TORCH_CUDA_VERSION}').\n"
      "Please install a CUDA version of PyTorch or disable OT_WITH_CUDA.\n"
    )
  endif()

  if(NOT CUDA_VERSION VERSION_EQUAL TORCH_CUDA_VERSION)
    message(FATAL_ERROR
      "PyTorch ${TORCH_VERSION} is compiled with CUDA ${TORCH_CUDA_VERSION}.\n"
      "But you are using CUDA ${CUDA_VERSION} to compile optimized_transducer.\n"
      "Please try to use the same CUDA version for PyTorch and optimized_transducer.\n"
      "**You can remove this check if you are sure this will not cause "
      "problems**\n"
    )
  endif()

  # Solve the following error for NVCC:
  #   unknown option `-Wall`
  #
  # It contains only some -Wno-* flags, so it is OK
  # to set them to empty
  #
  # The TARGET guard avoids a hard configure error from set_property() if
  # one of these imported targets is not defined.
  foreach(ot_torch_target IN ITEMS torch_cuda torch_cpu)
    if(TARGET ${ot_torch_target})
      set_property(TARGET ${ot_torch_target}
        PROPERTY
          INTERFACE_COMPILE_OPTIONS ""
      )
    endif()
  endforeach()
endif()