Commit 407f53e2 authored by Romeo Valentin's avatar Romeo Valentin

Let torch determine correct cuda architecture

See the `CUDAExtension` docstring in `torch/utils/cpp_extension.py`:
>   By default the extension will be compiled to run on all archs of the cards visible during the
>   building process of the extension, plus PTX. If down the road a new card is installed the
>   extension may need to be recompiled. If a visible card has a compute capability (CC) that's
>   newer than the newest version for which your nvcc can build fully-compiled binaries, Pytorch
>   will make nvcc fall back to building kernels with the newest version of PTX your nvcc does
>   support (see below for details on PTX).
>
>   You can override the default behavior using `TORCH_CUDA_ARCH_LIST` to explicitly specify which
>   CCs you want the extension to support:
>
>   `TORCH_CUDA_ARCH_LIST="6.1 8.6" python build_my_extension.py`
>   `TORCH_CUDA_ARCH_LIST="5.2 6.0 6.1 7.0 7.5 8.0 8.6+PTX" python build_my_extension.py`
>
>   The +PTX option causes extension kernel binaries to include PTX instructions for the specified
>   CC. PTX is an intermediate representation that allows kernels to runtime-compile for any CC >=
>   the specified CC (for example, 8.6+PTX generates PTX that can runtime-compile for any GPU with
>   CC >= 8.6). This improves your binary's forward compatibility. However, relying on older PTX to
>   provide forward compat by runtime-compiling for newer CCs can modestly reduce performance on
>   those newer CCs. If you know exact CC(s) of the GPUs you want to target, you're always better
>   off specifying them individually. For example, if you want your extension to run on 8.0 and 8.6,
>   "8.0+PTX" would work functionally because it includes PTX that can runtime-compile for 8.6, but
>   "8.0 8.6" would be better.
>
>   Note that while it's possible to include all supported archs, the more archs get included the
>   slower the building process will be, as it will build a separate kernel image for each arch.
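The arch-spec syntax described above can be sketched as follows — a simplified, self-contained model of how entries like `8.6` or `8.0+PTX` map to `nvcc -gencode` flags. This is an illustration, not PyTorch's actual implementation (the real logic lives inside `torch.utils.cpp_extension` and handles named architectures and more edge cases):

```python
def arch_flags(arch_list):
    """Translate arch specs like '8.6' or '8.0+PTX' into nvcc -gencode flags.

    'code=sm_XX' emits a fully-compiled (SASS) binary for that CC; a '+PTX'
    suffix additionally embeds 'code=compute_XX' (PTX), which lets GPUs with
    CC >= XX JIT-compile the kernel at runtime.
    """
    flags = []
    for arch in arch_list.split():
        ptx = arch.endswith('+PTX')
        num = arch.removesuffix('+PTX').replace('.', '')
        flags.append(f'-gencode=arch=compute_{num},code=sm_{num}')
        if ptx:
            flags.append(f'-gencode=arch=compute_{num},code=compute_{num}')
    return flags
```

For example, `arch_flags('8.0 8.6')` yields one fully-compiled image per CC, while `arch_flags('8.0+PTX')` yields an 8.0 binary plus 8.0 PTX for forward compatibility.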
parent 5f4f9c55
```diff
@@ -9,7 +9,6 @@ if(WITH_CUDA)
   enable_language(CUDA)
   add_definitions(-D__CUDA_NO_HALF_OPERATORS__)
   add_definitions(-DWITH_CUDA)
-  set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -arch=sm_35 --expt-relaxed-constexpr")
 endif()
 find_package(Python3 COMPONENTS Development)
```
```diff
@@ -62,7 +62,7 @@ def get_extensions():
         define_macros += [('WITH_CUDA', None)]
         nvcc_flags = os.getenv('NVCC_FLAGS', '')
         nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ')
-        nvcc_flags += ['-arch=sm_35', '--expt-relaxed-constexpr', '-O2']
+        nvcc_flags += ['--expt-relaxed-constexpr', '-O2']
         extra_compile_args['nvcc'] = nvcc_flags
         if sys.platform == 'win32':
```