Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
norm
vllm
Commits
a425bd9a
Unverified
Commit
a425bd9a
authored
Sep 26, 2023
by
Woosuk Kwon
Committed by
GitHub
Sep 26, 2023
Browse files
[Setup] Enable `TORCH_CUDA_ARCH_LIST` for selecting target GPUs (#1074)
parent
bbbf8656
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
73 additions
and
38 deletions
+73
-38
setup.py
setup.py
+73
-38
No files found.
setup.py
View file @
a425bd9a
...
...
@@ -3,6 +3,7 @@ import os
import
re
import
subprocess
from
typing
import
List
,
Set
import
warnings
from
packaging.version
import
parse
,
Version
import
setuptools
...
...
@@ -11,6 +12,9 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
ROOT_DIR
=
os
.
path
.
dirname
(
__file__
)
# Supported NVIDIA GPU architectures.
SUPPORTED_ARCHS
=
[
"7.0"
,
"7.5"
,
"8.0"
,
"8.6"
,
"8.9"
,
"9.0"
]
# Compiler flags.
CXX_FLAGS
=
[
"-g"
,
"-O2"
,
"-std=c++17"
]
# TODO(woosuk): Should we use -O3?
...
...
@@ -38,51 +42,82 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
return
nvcc_cuda_version
# Collect the compute capabilities of all available GPUs.
device_count
=
torch
.
cuda
.
device_count
()
compute_capabilities
:
Set
[
int
]
=
set
()
for
i
in
range
(
device_count
):
major
,
minor
=
torch
.
cuda
.
get_device_capability
(
i
)
if
major
<
7
:
raise
RuntimeError
(
"GPUs with compute capability less than 7.0 are not supported."
)
compute_capabilities
.
add
(
major
*
10
+
minor
)
def get_torch_arch_list() -> Set[str]:
    """Return the CUDA architectures requested via TORCH_CUDA_ARCH_LIST.

    TORCH_CUDA_ARCH_LIST can have one or more architectures separated by
    ";" or whitespace, e.g. "8.0" or "7.5;8.0;8.6+PTX". Here, the "8.6+PTX"
    option asks the compiler to additionally include PTX code that can be
    runtime-compiled and executed on the 8.6 or newer architectures. While
    the PTX code will not give the best performance on the newer
    architectures, it provides forward compatibility.

    Returns:
        The set of requested architecture strings. The set is empty when the
        environment variable is unset or contains no entries.

    Raises:
        ValueError: If an entry is not one of SUPPORTED_ARCHS, optionally
            suffixed with "+PTX".
    """
    valid_arch_strs = SUPPORTED_ARCHS + [s + "+PTX" for s in SUPPORTED_ARCHS]
    arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
    if arch_list is None:
        return set()

    # Entries are separated by ";" or whitespace. Normalizing to whitespace
    # and calling split() without an argument drops empty tokens, so an empty
    # variable, a trailing ";" or "8.0; 8.6" does not yield a bogus "" entry.
    archs = arch_list.replace(";", " ").split()
    for arch in archs:
        if arch not in valid_arch_strs:
            raise ValueError(
                f"Unsupported CUDA arch ({arch}). "
                f"Valid CUDA arch strings are: {valid_arch_strs}.")
    return set(archs)
# First, check the TORCH_CUDA_ARCH_LIST environment variable.
compute_capabilities
=
get_torch_arch_list
()
if
not
compute_capabilities
:
# If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available
# GPUs on the current machine.
device_count
=
torch
.
cuda
.
device_count
()
for
i
in
range
(
device_count
):
major
,
minor
=
torch
.
cuda
.
get_device_capability
(
i
)
if
major
<
7
:
raise
RuntimeError
(
"GPUs with compute capability below 7.0 are not supported."
)
compute_capabilities
.
add
(
f
"
{
major
}
.
{
minor
}
"
)
# Validate the NVCC CUDA version.
nvcc_cuda_version
=
get_nvcc_cuda_version
(
CUDA_HOME
)
if
not
compute_capabilities
:
# If no GPU is specified nor available, add all supported architectures
# based on the NVCC CUDA version.
compute_capabilities
=
set
(
SUPPORTED_ARCHS
)
if
nvcc_cuda_version
<
Version
(
"11.1"
):
compute_capabilities
.
remove
(
"8.6"
)
if
nvcc_cuda_version
<
Version
(
"11.8"
):
compute_capabilities
.
remove
(
"8.9"
)
compute_capabilities
.
remove
(
"9.0"
)
# Validate the NVCC CUDA version.
if
nvcc_cuda_version
<
Version
(
"11.0"
):
raise
RuntimeError
(
"CUDA 11.0 or higher is required to build the package."
)
if
86
in
compute_capabilities
and
nvcc_cuda_version
<
Version
(
"11.1"
):
raise
RuntimeError
(
"CUDA 11.1 or higher is required for GPUs with compute capability 8.6."
)
if
89
in
compute_capabilities
and
nvcc_cuda_version
<
Version
(
"11.8"
):
# CUDA 11.8 is required to generate the code targeting compute capability 8.9.
# However, GPUs with compute capability 8.9 can also run the code generated by
# the previous versions of CUDA 11 and targeting compute capability 8.0.
# Therefore, if CUDA 11.8 is not available, we target compute capability 8.0
# instead of 8.9.
compute_capabilities
.
remove
(
89
)
compute_capabilities
.
add
(
80
)
if
90
in
compute_capabilities
and
nvcc_cuda_version
<
Version
(
"11.8"
):
raise
RuntimeError
(
"CUDA 11.8 or higher is required for GPUs with compute capability 9.0."
)
# If no GPU is available, add all supported compute capabilities.
if
not
compute_capabilities
:
compute_capabilities
=
{
70
,
75
,
80
}
if
nvcc_cuda_version
>=
Version
(
"11.1"
):
compute_capabilities
.
add
(
86
)
if
nvcc_cuda_version
>=
Version
(
"11.8"
):
compute_capabilities
.
add
(
89
)
compute_capabilities
.
add
(
90
)
if
nvcc_cuda_version
<
Version
(
"11.1"
):
if
any
(
cc
.
startswith
(
"8.6"
)
for
cc
in
compute_capabilities
):
raise
RuntimeError
(
"CUDA 11.1 or higher is required for compute capability 8.6."
)
if
nvcc_cuda_version
<
Version
(
"11.8"
):
if
any
(
cc
.
startswith
(
"8.9"
)
for
cc
in
compute_capabilities
):
# CUDA 11.8 is required to generate the code targeting compute capability 8.9.
# However, GPUs with compute capability 8.9 can also run the code generated by
# the previous versions of CUDA 11 and targeting compute capability 8.0.
# Therefore, if CUDA 11.8 is not available, we target compute capability 8.0
# instead of 8.9.
warnings
.
warn
(
"CUDA 11.8 or higher is required for compute capability 8.9. "
"Targeting compute capability 8.0 instead."
)
compute_capabilities
=
set
(
cc
for
cc
in
compute_capabilities
if
not
cc
.
startswith
(
"8.9"
))
compute_capabilities
.
add
(
"8.0+PTX"
)
if
any
(
cc
.
startswith
(
"9.0"
)
for
cc
in
compute_capabilities
):
raise
RuntimeError
(
"CUDA 11.8 or higher is required for compute capability 9.0."
)
# Add target compute capabilities to NVCC flags.
for capability in compute_capabilities:
    # Capabilities are strings such as "8.6" or "8.6+PTX", while nvcc expects
    # the dotless form ("86"). Strip the optional "+PTX" suffix and the dot
    # instead of indexing fixed character positions, so that a two-digit
    # major version would also be converted correctly.
    num = capability.split("+")[0].replace(".", "")
    NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
    if capability.endswith("+PTX"):
        # Additionally embed PTX so that newer GPUs can JIT-compile the code.
        NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"]
# Use NVCC threads to parallelize the build.
if
nvcc_cuda_version
>=
Version
(
"11.2"
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment