Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
norm
vllm
Commits
d0740dff
Unverified
Commit
d0740dff
authored
Oct 14, 2023
by
Woosuk Kwon
Committed by
GitHub
Oct 14, 2023
Browse files
Fix error message on `TORCH_CUDA_ARCH_LIST` (#1239)
Co-authored-by:
Yunfeng Bai
<
yunfeng.bai@scale.com
>
parent
de894728
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
25 additions
and
12 deletions
+25
-12
setup.py
setup.py
+25
-12
No files found.
setup.py
View file @
d0740dff
...
...
@@ -13,7 +13,7 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
ROOT_DIR
=
os
.
path
.
dirname
(
__file__
)
# Supported NVIDIA GPU architectures.
SUPPORTED_ARCHS
=
[
"7.0"
,
"7.5"
,
"8.0"
,
"8.6"
,
"8.9"
,
"9.0"
]
SUPPORTED_ARCHS
=
{
"7.0"
,
"7.5"
,
"8.0"
,
"8.6"
,
"8.9"
,
"9.0"
}
# Compiler flags.
CXX_FLAGS
=
[
"-g"
,
"-O2"
,
"-std=c++17"
]
...
...
@@ -49,19 +49,32 @@ def get_torch_arch_list() -> Set[str]:
# and executed on the 8.6 or newer architectures. While the PTX code will
# not give the best performance on the newer architectures, it provides
# forward compatibility.
valid_arch_strs
=
SUPPORTED_ARCHS
+
[
s
+
"+PTX"
for
s
in
SUPPORTED_ARCHS
]
arch_list
=
os
.
environ
.
get
(
"TORCH_CUDA_ARCH_LIST"
,
None
)
if
arch_list
is
None
:
env_arch_list
=
os
.
environ
.
get
(
"TORCH_CUDA_ARCH_LIST"
,
None
)
if
env_arch_list
is
None
:
return
set
()
# List are separated by ; or space.
arch_list
=
arch_list
.
replace
(
" "
,
";"
).
split
(
";"
)
for
arch
in
arch_list
:
if
arch
not
in
valid_arch_strs
:
raise
ValueError
(
f
"Unsupported CUDA arch (
{
arch
}
). "
f
"Valid CUDA arch strings are:
{
valid_arch_strs
}
."
)
return
set
(
arch_list
)
torch_arch_list
=
set
(
env_arch_list
.
replace
(
" "
,
";"
).
split
(
";"
))
if
not
torch_arch_list
:
return
set
()
# Filter out the invalid architectures and print a warning.
valid_archs
=
SUPPORTED_ARCHS
.
union
({
s
+
"+PTX"
for
s
in
SUPPORTED_ARCHS
})
arch_list
=
torch_arch_list
.
intersection
(
valid_archs
)
# If none of the specified architectures are valid, raise an error.
if
not
arch_list
:
raise
RuntimeError
(
"None of the CUDA architectures in `TORCH_CUDA_ARCH_LIST` env "
f
"variable (
{
env_arch_list
}
) is supported. "
f
"Supported CUDA architectures are:
{
valid_archs
}
."
)
invalid_arch_list
=
torch_arch_list
-
valid_archs
if
invalid_arch_list
:
warnings
.
warn
(
f
"Unsupported CUDA architectures (
{
invalid_arch_list
}
) are "
"excluded from the `TORCH_CUDA_ARCH_LIST` env variable "
f
"(
{
env_arch_list
}
). Supported CUDA architectures are: "
f
"
{
valid_archs
}
."
)
return
arch_list
# First, check the TORCH_CUDA_ARCH_LIST environment variable.
...
...
@@ -81,7 +94,7 @@ nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
if
not
compute_capabilities
:
# If no GPU is specified nor available, add all supported architectures
# based on the NVCC CUDA version.
compute_capabilities
=
set
(
SUPPORTED_ARCHS
)
compute_capabilities
=
SUPPORTED_ARCHS
.
copy
(
)
if
nvcc_cuda_version
<
Version
(
"11.1"
):
compute_capabilities
.
remove
(
"8.6"
)
if
nvcc_cuda_version
<
Version
(
"11.8"
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment