Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
98fa5a3b
Commit
98fa5a3b
authored
Jun 14, 2018
by
Michael Carilli
Browse files
Pulling in old logic to manually look for CUDA_HOME in Pytorch <= 0.4 to allow cross-compilation
parent
0f703d13
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
44 additions
and
0 deletions
+44
-0
setup.py
setup.py
+44
-0
No files found.
setup.py
View file @
98fa5a3b
...
@@ -36,6 +36,49 @@ def find(path, regex_func, collect=False):
...
@@ -36,6 +36,49 @@ def find(path, regex_func, collect=False):
return
os
.
path
.
join
(
root
,
file
)
return
os
.
path
.
join
(
root
,
file
)
return
list
(
set
(
collection
))
return
list
(
set
(
collection
))
# Due to https://github.com/pytorch/pytorch/issues/8223, for Pytorch <= 0.4
# torch.utils.cpp_extension's check for CUDA_HOME fails if there are no GPUs
# available on the system, which prevents cross-compiling and building via Dockerfiles.
# Workaround: manually search for CUDA_HOME if Pytorch <= 0.4.
def
find_cuda_home
():
cuda_path
=
None
CUDA_HOME
=
None
CUDA_HOME
=
os
.
getenv
(
'CUDA_HOME'
,
'/usr/local/cuda'
)
if
not
os
.
path
.
exists
(
CUDA_HOME
):
# We use nvcc path on Linux and cudart path on macOS
cudart_path
=
ctypes
.
util
.
find_library
(
'cudart'
)
if
cudart_path
is
not
None
:
cuda_path
=
os
.
path
.
dirname
(
cudart_path
)
if
cuda_path
is
not
None
:
CUDA_HOME
=
os
.
path
.
dirname
(
cuda_path
)
if
not
cuda_path
and
not
CUDA_HOME
:
nvcc_path
=
find
(
'/usr/local/'
,
re
.
compile
(
"nvcc"
).
search
,
False
)
if
nvcc_path
:
CUDA_HOME
=
os
.
path
.
dirname
(
nvcc_path
)
if
CUDA_HOME
:
os
.
path
.
dirname
(
CUDA_HOME
)
if
(
not
os
.
path
.
exists
(
CUDA_HOME
+
os
.
sep
+
"lib64"
)
or
not
os
.
path
.
exists
(
CUDA_HOME
+
os
.
sep
+
"include"
)
):
raise
RuntimeError
(
"Error: found NVCC at "
,
nvcc_path
,
" but could not locate CUDA libraries"
+
" or include directories."
)
raise
RuntimeError
(
"Error: Could not find cuda on this system. "
+
"Please set your CUDA_HOME enviornment variable "
"to the CUDA base directory."
)
return
CUDA_HOME
if
TORCH_MAJOR
==
0
and
TORCH_MINOR
==
4
:
if
CUDA_HOME
is
None
:
CUDA_HOME
=
find_cuda_home
()
# Patch cpp_extension's view of CUDA_HOME:
torch
.
utils
.
cpp_extension
.
CUDA_HOME
=
CUDA_HOME
def
get_cuda_version
():
def
get_cuda_version
():
NVCC
=
find
(
CUDA_HOME
+
os
.
sep
+
"bin"
,
NVCC
=
find
(
CUDA_HOME
+
os
.
sep
+
"bin"
,
re
.
compile
(
'nvcc$'
).
search
)
re
.
compile
(
'nvcc$'
).
search
)
...
@@ -55,6 +98,7 @@ def get_cuda_version():
...
@@ -55,6 +98,7 @@ def get_cuda_version():
return
CUDA_MAJOR
return
CUDA_MAJOR
if
CUDA_HOME
is
not
None
:
if
CUDA_HOME
is
not
None
:
print
(
"Found CUDA_HOME = "
,
CUDA_HOME
)
print
(
"Found CUDA_HOME = "
,
CUDA_HOME
)
CUDA_MAJOR
=
get_cuda_version
()
CUDA_MAJOR
=
get_cuda_version
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment