Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
bc64ee83
Commit
bc64ee83
authored
Sep 07, 2022
by
hubertlu-tw
Browse files
Keep --peer_memory and --nccl_p2p CUDA-compatible
parent
fd0f7631
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
57 additions
and
0 deletions
+57
-0
setup.py
setup.py
+57
-0
No files found.
setup.py
View file @
bc64ee83
...
@@ -42,6 +42,55 @@ def get_cuda_bare_metal_version(cuda_dir):
...
@@ -42,6 +42,55 @@ def get_cuda_bare_metal_version(cuda_dir):
return
raw_output
,
bare_metal_major
,
bare_metal_minor
return
raw_output
,
bare_metal_major
,
bare_metal_minor
def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
    """Abort the build if nvcc's CUDA version differs from PyTorch's.

    Compares the major.minor version reported by the ``nvcc`` under
    *cuda_dir* with the CUDA version PyTorch was compiled against
    (``torch.version.cuda``) and raises ``RuntimeError`` on any mismatch.
    """
    raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
    # Both sides are "major"/"minor" strings split out of a dotted version,
    # so plain string equality is the intended comparison.
    torch_binary_major, torch_binary_minor = torch.version.cuda.split(".")[:2]

    print("\nCompiling cuda extensions with")
    print(raw_output + "from " + cuda_dir + "/bin\n")

    versions_match = (
        bare_metal_major == torch_binary_major
        and bare_metal_minor == torch_binary_minor
    )
    if not versions_match:
        raise RuntimeError(
            "Cuda extensions are being compiled with a version of Cuda that does "
            "not match the version used to compile Pytorch binaries. "
            "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda)
            + "In some cases, a minor-version mismatch will not cause later errors: "
            "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. "
            "You can try commenting out this check (at your own risk)."
        )
def raise_if_cuda_home_none(global_option: str) -> None:
    """Raise ``RuntimeError`` when ``CUDA_HOME`` is unset.

    Called before building a CUDA extension requested via *global_option*;
    a ``None`` ``CUDA_HOME`` means nvcc could not be located, so the build
    cannot proceed. Returns ``None`` (and does nothing) when CUDA is found.
    """
    if CUDA_HOME is None:
        raise RuntimeError(
            f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
            "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
            "only images whose names contain 'devel' will provide nvcc."
        )
def append_nvcc_threads(nvcc_extra_args):
    """Enable parallel nvcc compilation when the toolkit supports it.

    nvcc gained the ``--threads`` option in CUDA 11.2; for any bare-metal
    toolkit at least that new, append ``--threads 4`` to the extra args.

    Args:
        nvcc_extra_args: list of nvcc flag strings; never mutated.

    Returns:
        A new list with ``["--threads", "4"]`` appended when supported,
        otherwise the original list unchanged.
    """
    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
    # Compare as a (major, minor) tuple: the previous check
    # `int(major) >= 11 and int(minor) >= 2` wrongly excluded CUDA 12.0
    # and 12.1, whose minor component is below 2 even though the toolkit
    # is newer than 11.2.
    if (int(bare_metal_major), int(bare_metal_minor)) >= (11, 2):
        return nvcc_extra_args + ["--threads", "4"]
    return nvcc_extra_args
def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int) -> bool:
    """Check that cuDNN is present at *required_cudnn_version* or newer.

    Returns ``True`` when the requirement is met; otherwise emits a
    warning explaining why *global_option* will be skipped and returns
    ``False``.
    """
    is_available = torch.backends.cudnn.is_available()
    version = torch.backends.cudnn.version() if is_available else None
    if is_available and version >= required_cudnn_version:
        return True
    # Explain what is missing: cuDNN entirely, or just a too-old version.
    reason = 'cuDNN is not available' if not is_available else version
    warnings.warn(
        f"Skip `{global_option}` as it requires cuDNN {required_cudnn_version} or later, "
        f"but {reason}"
    )
    return False
# Echo the PyTorch version this build is compiling against.
print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))

# Numeric major/minor components of the installed PyTorch version;
# later build logic gates features on these.
TORCH_MAJOR, TORCH_MINOR = map(int, torch.__version__.split(".")[:2])
...
@@ -539,6 +588,10 @@ if "--fast_bottleneck" in sys.argv:
...
@@ -539,6 +588,10 @@ if "--fast_bottleneck" in sys.argv:
if
"--peer_memory"
in
sys
.
argv
or
"--cuda_ext"
in
sys
.
argv
:
if
"--peer_memory"
in
sys
.
argv
or
"--cuda_ext"
in
sys
.
argv
:
if
"--peer_memory"
in
sys
.
argv
:
if
"--peer_memory"
in
sys
.
argv
:
sys
.
argv
.
remove
(
"--peer_memory"
)
sys
.
argv
.
remove
(
"--peer_memory"
)
if
not
IS_ROCM_PYTORCH
:
raise_if_cuda_home_none
(
"--peer_memory"
)
ext_modules
.
append
(
ext_modules
.
append
(
CUDAExtension
(
CUDAExtension
(
name
=
"peer_memory_cuda"
,
name
=
"peer_memory_cuda"
,
...
@@ -553,6 +606,10 @@ if "--peer_memory" in sys.argv or "--cuda_ext" in sys.argv:
...
@@ -553,6 +606,10 @@ if "--peer_memory" in sys.argv or "--cuda_ext" in sys.argv:
if
"--nccl_p2p"
in
sys
.
argv
or
"--cuda_ext"
in
sys
.
argv
:
if
"--nccl_p2p"
in
sys
.
argv
or
"--cuda_ext"
in
sys
.
argv
:
if
"--nccl_p2p"
in
sys
.
argv
:
if
"--nccl_p2p"
in
sys
.
argv
:
sys
.
argv
.
remove
(
"--nccl_p2p"
)
sys
.
argv
.
remove
(
"--nccl_p2p"
)
if
not
IS_ROCM_PYTORCH
:
raise_if_cuda_home_none
(
"--nccl_p2p"
)
ext_modules
.
append
(
ext_modules
.
append
(
CUDAExtension
(
CUDAExtension
(
name
=
"nccl_p2p_cuda"
,
name
=
"nccl_p2p_cuda"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment