OpenDAS / apex / Commits

Commit 5b53121a (Unverified)
Authored Aug 01, 2020 by ptrblck; committed by GitHub, Aug 01, 2020
Parent: 700d6825

    Add sm80 for CUDA >= 11 (#925)

Showing 1 changed file, setup.py, with 40 additions and 11 deletions (+40, -11).

setup.py
 import torch
+from torch.utils import cpp_extension
 from setuptools import setup, find_packages
 import subprocess
...
@@ -9,6 +10,16 @@ import os
 # ninja build does not work unless include_dirs are abs path
 this_dir = os.path.dirname(os.path.abspath(__file__))
 
+def get_cuda_bare_metal_version(cuda_dir):
+    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
+    output = raw_output.split()
+    release_idx = output.index("release") + 1
+    release = output[release_idx].split(".")
+    bare_metal_major = release[0]
+    bare_metal_minor = release[1][0]
+
+    return raw_output, bare_metal_major, bare_metal_minor
+
 if not torch.cuda.is_available():
     # https://github.com/NVIDIA/apex/issues/486
     # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(),
...
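A quick illustration of what the new helper extracts, not part of the commit: the banner below is a hypothetical sample of `nvcc -V` output (real text varies by toolkit build), and only the token after "release" matters to the parser.

    # Hypothetical `nvcc -V` banner used only to exercise the parsing logic.
    sample = ("nvcc: NVIDIA (R) Cuda compiler driver\n"
              "Cuda compilation tools, release 11.0, V11.0.194\n")
    tokens = sample.split()
    release = tokens[tokens.index("release") + 1].split(".")  # ["11", "0,"]
    major, minor = release[0], release[1][0]
    print(major, minor)  # -> 11 0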
@@ -16,11 +27,16 @@ if not torch.cuda.is_available():
     print('\nWarning: Torch did not find available GPUs on this system.\n',
           'If your intention is to cross-compile, this is not an error.\n'
           'By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n'
-          'Volta (compute capability 7.0), and Turing (compute capability 7.5).\n'
+          'Volta (compute capability 7.0), Turing (compute capability 7.5),\n'
+          'and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n'
           'If you wish to cross-compile for a single specific architecture,\n'
           'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n')
     if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
-        os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
+        _, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
+        if int(bare_metal_major) == 11:
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
+        else:
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
 
 print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
 TORCH_MAJOR = int(torch.__version__.split('.')[0])
...
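As the updated warning says, the CUDA-version-based default only applies when TORCH_CUDA_ARCH_LIST is unset; setting it beforehand pins the cross-compile to a chosen architecture. A minimal sketch of doing that from a scripted build (equivalent to the shell `export` the message suggests):

    import os
    # Pin the build to Ampere only; the `is None` guard above then
    # leaves this value untouched.
    os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0"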
@@ -64,13 +80,18 @@ if "--cpp_ext" in sys.argv:
         CppExtension('apex_C',
                      ['csrc/flatten_unflatten.cpp',]))
 
-def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
+def get_cuda_bare_metal_version(cuda_dir):
     raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
     output = raw_output.split()
     release_idx = output.index("release") + 1
     release = output[release_idx].split(".")
     bare_metal_major = release[0]
     bare_metal_minor = release[1][0]
 
+    return raw_output, bare_metal_major, bare_metal_minor
+
+def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
+    raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
+
     torch_binary_major = torch.version.cuda.split(".")[0]
     torch_binary_minor = torch.version.cuda.split(".")[1]
...
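For orientation (the failure itself sits in lines elided above): check_cuda_torch_binary_vs_bare_metal compares the major/minor pair reported by nvcc against torch.version.cuda and aborts the build on mismatch, per the error text visible in the next hunk. A standalone sketch with stand-in values; the message string is illustrative, not the commit's wording:

    torch_binary = "11.0"      # stand-in for torch.version.cuda
    bare_metal = ("11", "0")   # stand-in for the (major, minor) pair from nvcc -V
    if tuple(torch_binary.split(".")[:2]) != bare_metal:
        # Illustrative error; the committed wording is in the elided lines.
        raise RuntimeError("CUDA version mismatch: torch binaries vs bare-metal nvcc")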
@@ -85,6 +106,7 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
                            "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. "
                            "You can try commenting out this check (at your own risk)."
                            )
+
 # Set up macros for forward/backward compatibility hack around
 # https://github.com/pytorch/pytorch/commit/4404762d7dd955383acee92e6f06b48144a0742e
 # and
...
@@ -256,6 +278,13 @@ torch_dir = torch.__path__[0]
 if os.path.exists(os.path.join(torch_dir, 'include', 'ATen', 'CUDAGenerator.h')):
     generator_flag = ['-DOLD_GENERATOR']
 
+# Check, if CUDA11 is installed for compute capability 8.0
+cc_flag = []
+_, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
+if int(bare_metal_major) >= 11:
+    cc_flag.append('-gencode')
+    cc_flag.append('arch=compute_80,code=sm_80')
+
 if "--fast_multihead_attn" in sys.argv:
     from torch.utils.cpp_extension import CUDAExtension
     sys.argv.remove("--fast_multihead_attn")
...
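For context, not part of the diff: '-gencode arch=compute_80,code=sm_80' asks nvcc to target the compute_80 virtual architecture and embed a binary for sm_80 (Ampere). The sketch below mirrors how cc_flag is folded into each extension's nvcc arguments in the hunks that follow; the two empty lists are placeholders for values built earlier in setup.py:

    cc_flag = ['-gencode', 'arch=compute_80,code=sm_80']  # as built above on CUDA >= 11
    version_dependent_macros, generator_flag = [], []      # placeholders for this sketch
    nvcc_args = (['--expt-relaxed-constexpr',
                  '--expt-extended-lambda',
                  '--use_fast_math']
                 + version_dependent_macros + generator_flag + cc_flag)
    print(nvcc_args[-2:])  # -> ['-gencode', 'arch=compute_80,code=sm_80']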
@@ -279,7 +308,7 @@ if "--fast_multihead_attn" in sys.argv:
                        '-U__CUDA_NO_HALF_CONVERSIONS__',
                        '--expt-relaxed-constexpr',
                        '--expt-extended-lambda',
-                       '--use_fast_math'] + version_dependent_macros + generator_flag}))
+                       '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
     ext_modules.append(
         CUDAExtension(name='fast_mask_softmax_dropout',
                       sources=['apex/contrib/csrc/multihead_attn/masked_softmax_dropout.cpp',
...
@@ -292,7 +321,7 @@ if "--fast_multihead_attn" in sys.argv:
                        '-U__CUDA_NO_HALF_CONVERSIONS__',
                        '--expt-relaxed-constexpr',
                        '--expt-extended-lambda',
-                       '--use_fast_math'] + version_dependent_macros + generator_flag}))
+                       '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
     ext_modules.append(
         CUDAExtension(name='fast_self_multihead_attn_bias_additive_mask',
                       sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask.cpp',
...
@@ -305,7 +334,7 @@ if "--fast_multihead_attn" in sys.argv:
                        '-U__CUDA_NO_HALF_CONVERSIONS__',
                        '--expt-relaxed-constexpr',
                        '--expt-extended-lambda',
-                       '--use_fast_math'] + version_dependent_macros + generator_flag}))
+                       '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
     ext_modules.append(
         CUDAExtension(name='fast_self_multihead_attn_bias',
                       sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_bias.cpp',
...
@@ -318,7 +347,7 @@ if "--fast_multihead_attn" in sys.argv:
                        '-U__CUDA_NO_HALF_CONVERSIONS__',
                        '--expt-relaxed-constexpr',
                        '--expt-extended-lambda',
-                       '--use_fast_math'] + version_dependent_macros + generator_flag}))
+                       '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
     ext_modules.append(
         CUDAExtension(name='fast_self_multihead_attn',
                       sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn.cpp',
...
@@ -331,7 +360,7 @@ if "--fast_multihead_attn" in sys.argv:
                        '-U__CUDA_NO_HALF_CONVERSIONS__',
                        '--expt-relaxed-constexpr',
                        '--expt-extended-lambda',
-                       '--use_fast_math'] + version_dependent_macros + generator_flag}))
+                       '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
     ext_modules.append(
         CUDAExtension(name='fast_self_multihead_attn_norm_add',
                       sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add.cpp',
...
@@ -344,7 +373,7 @@ if "--fast_multihead_attn" in sys.argv:
                        '-U__CUDA_NO_HALF_CONVERSIONS__',
                        '--expt-relaxed-constexpr',
                        '--expt-extended-lambda',
-                       '--use_fast_math'] + version_dependent_macros + generator_flag}))
+                       '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
     ext_modules.append(
         CUDAExtension(name='fast_encdec_multihead_attn',
                       sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn.cpp',
...
@@ -357,7 +386,7 @@ if "--fast_multihead_attn" in sys.argv:
                        '-U__CUDA_NO_HALF_CONVERSIONS__',
                        '--expt-relaxed-constexpr',
                        '--expt-extended-lambda',
-                       '--use_fast_math'] + version_dependent_macros + generator_flag}))
+                       '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
     ext_modules.append(
         CUDAExtension(name='fast_encdec_multihead_attn_norm_add',
                       sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add.cpp',
...
@@ -370,7 +399,7 @@ if "--fast_multihead_attn" in sys.argv:
                        '-U__CUDA_NO_HALF_CONVERSIONS__',
                        '--expt-relaxed-constexpr',
                        '--expt-extended-lambda',
-                       '--use_fast_math'] + version_dependent_macros + generator_flag}))
+                       '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
 
 setup(
     name='apex',
...
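End to end, a hedged build sketch: on a machine whose bare-metal CUDA is 11.x, the sm_80 target is now picked up automatically. Of these flags, --cpp_ext and --fast_multihead_attn appear in this diff; --cuda_ext is apex's usual companion build flag but is not shown here.

    import subprocess, sys
    # Equivalent shell command:
    #   python setup.py install --cpp_ext --cuda_ext --fast_multihead_attn
    subprocess.check_call([sys.executable, "setup.py", "install",
                           "--cpp_ext", "--cuda_ext", "--fast_multihead_attn"])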