Unverified Commit 8cbd56c2 authored by Nikita Shulga's avatar Nikita Shulga Committed by GitHub
Browse files

setup.py should parse TORCH_CUDA_ARCH_LIST (#1733)

Needed to support CUDA builds on CPU machine

Parse `TORCH_CUDA_ARCH_LIST` as new-CUDA-language Cmake-3.18+ style [CMAKE_CUDA_ARCHITECTURES](https://cmake.org/cmake/help/latest/prop_tgt/CUDA_ARCHITECTURES.html#prop_tgt:CUDA_ARCHITECTURES)
parent f13f211b
cmake_minimum_required(VERSION 3.5 FATAL_ERROR) cmake_minimum_required(VERSION 3.18 FATAL_ERROR)
# Most of the configurations are taken from PyTorch # Most of the configurations are taken from PyTorch
# https://github.com/pytorch/pytorch/blob/0c9fb4aff0d60eaadb04e4d5d099fb1e1d5701a9/CMakeLists.txt # https://github.com/pytorch/pytorch/blob/0c9fb4aff0d60eaadb04e4d5d099fb1e1d5701a9/CMakeLists.txt
......
...@@ -39,6 +39,7 @@ _BUILD_KALDI = False if platform.system() == 'Windows' else _get_build("BUILD_KA ...@@ -39,6 +39,7 @@ _BUILD_KALDI = False if platform.system() == 'Windows' else _get_build("BUILD_KA
_BUILD_RNNT = _get_build("BUILD_RNNT", True) _BUILD_RNNT = _get_build("BUILD_RNNT", True)
_USE_ROCM = _get_build("USE_ROCM") _USE_ROCM = _get_build("USE_ROCM")
_USE_CUDA = _get_build("USE_CUDA", torch.cuda.is_available()) _USE_CUDA = _get_build("USE_CUDA", torch.cuda.is_available())
_TORCH_CUDA_ARCH_LIST = os.environ.get('TORCH_CUDA_ARCH_LIST', None)
def get_ext_modules(): def get_ext_modules():
...@@ -82,6 +83,13 @@ class CMakeBuild(build_ext): ...@@ -82,6 +83,13 @@ class CMakeBuild(build_ext):
build_args = [ build_args = [
'--target', 'install' '--target', 'install'
] ]
# Pass CUDA architecture to cmake
if _TORCH_CUDA_ARCH_LIST is not None:
# Convert MAJOR.MINOR[+PTX] list to new style one
# defined at https://cmake.org/cmake/help/latest/prop_tgt/CUDA_ARCHITECTURES.html
_arches = _TORCH_CUDA_ARCH_LIST.replace('.', '').split(";")
_arches = [arch[:-4] if arch.endswith("+PTX") else f"{arch}-real" for arch in _arches]
cmake_args += [f"-DCMAKE_CUDA_ARCHITECTURES={';'.join(_arches)}"]
# Default to Ninja # Default to Ninja
if 'CMAKE_GENERATOR' not in os.environ or platform.system() == 'Windows': if 'CMAKE_GENERATOR' not in os.environ or platform.system() == 'Windows':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment