Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
5709cfb5
Commit
5709cfb5
authored
May 01, 2018
by
Christian Sarofeen
Browse files
Try to improve robustness of finding cuda in build. Try to support building with CUDA 8.
parent
1cea1005
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
69 additions
and
30 deletions
+69
-30
setup.py
setup.py
+69
-30
No files found.
setup.py
View file @
5709cfb5
...
...
@@ -17,10 +17,16 @@ import ctypes.util
import
torch
#Takes a path to walk
#A function to decide if to keep
#collection if we want a list of all occurances
def
find
(
path
,
regex_func
,
collect
=
False
):
"""
Recursively searches through a directory with regex_func and
either collects all instances or returns the first instance.
Args:
path: Directory to search through
regex_function: A function to run on each file to decide if it should be returned/collected
collect (False) : If True will collect all instances of matching, else will return first instance only
"""
collection
=
[]
if
collect
else
None
for
root
,
dirs
,
files
in
os
.
walk
(
path
):
for
file
in
files
:
...
...
@@ -31,25 +37,55 @@ def find(path, regex_func, collect=False):
return
os
.
path
.
join
(
root
,
file
)
return
list
(
set
(
collection
))
def
findcuda
():
"""
Based on PyTorch build process. Will look for nvcc for compilation.
Either will set cuda home by enviornment variable CUDA_HOME or will search
for nvcc. Returns NVCC executable, cuda major version and cuda home directory.
"""
cuda_path
=
None
CUDA_HOME
=
None
CUDA_HOME
=
os
.
getenv
(
'CUDA_HOME'
,
'/usr/local/cuda'
)
if
not
os
.
path
.
exists
(
CUDA_HOME
):
# We use nvcc path on Linux and cudart path on macOS
osname
=
platform
.
system
()
if
osname
==
'Linux'
:
cuda_path
=
find_nvcc
()
else
:
cudart_path
=
ctypes
.
util
.
find_library
(
'cudart'
)
if
cudart_path
is
not
None
:
cuda_path
=
os
.
path
.
dirname
(
cudart_path
)
else
:
cuda_path
=
None
cudart_path
=
ctypes
.
util
.
find_library
(
'cudart'
)
if
cudart_path
is
not
None
:
cuda_path
=
os
.
path
.
dirname
(
cudart_path
)
if
cuda_path
is
not
None
:
CUDA_HOME
=
os
.
path
.
dirname
(
cuda_path
)
else
:
CUDA_HOME
=
None
WITH_CUDA
=
CUDA_HOME
is
not
None
return
CUDA_HOME
if
not
cuda_path
and
not
CUDA_HOME
:
nvcc_path
=
find
(
'/usr/local/'
,
re
.
compile
(
"nvcc"
).
search
,
False
)
if
nvcc_path
:
CUDA_HOME
=
os
.
path
.
dirname
(
nvcc_path
)
if
CUDA_HOME
:
os
.
path
.
dirname
(
CUDA_HOME
)
if
(
not
os
.
path
.
exists
(
CUDA_HOME
+
os
.
sep
+
"lib64"
)
or
not
os
.
path
.
exists
(
CUDA_HOME
+
os
.
sep
+
"include"
)
):
raise
RuntimeError
(
"Error: found NVCC at "
,
nvcc_path
,
" but could not locate CUDA libraries"
+
" or include directories."
)
raise
RuntimeError
(
"Error: Could not find cuda on this system."
+
" Please set your CUDA_HOME enviornment variable to the CUDA base directory."
)
NVCC
=
find
(
CUDA_HOME
,
re
.
compile
(
'nvcc'
).
search
)
CUDA_LIB
=
find
(
CUDA_HOME
,
re
.
compile
(
'libcudart.so.*.*.*'
).
search
)
if
CUDA_LIB
:
try
:
CUDA_VERSION
=
int
(
CUDA_LIB
.
split
(
'.'
)[
2
])
except
(
ValueError
,
TypeError
):
CUDA_VERSION
=
9
else
:
CUDA_VERSION
=
9
if
CUDA_VERSION
<
8
:
raise
RuntimeError
(
"Error: APEx requires CUDA 8 or newer"
)
return
NVCC
,
CUDA_VERSION
,
CUDA_HOME
#Get some important paths
curdir
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
inspect
.
stack
()[
0
][
1
]))
...
...
@@ -87,7 +123,7 @@ extra_compile_args = ["--std=c++11",]
#findcuda returns root dir of CUDA
#include cuda/include and cuda/lib64 for python module build.
CUDA_HOME
=
findcuda
()
NVCC
,
CUDA_VERSION
,
CUDA_HOME
=
findcuda
()
library_dirs
.
append
(
os
.
path
.
join
(
CUDA_HOME
,
"lib64"
))
include_dirs
.
append
(
os
.
path
.
join
(
CUDA_HOME
,
'include'
))
...
...
@@ -107,22 +143,25 @@ class RMBuild(clean):
shutil
.
rmtree
(
eggdir
)
clean
.
run
(
self
)
def
CompileCudaFiles
():
def
CompileCudaFiles
(
NVCC
,
CUDA_VERSION
):
print
()
print
(
"Compiling cuda modules with nvcc:"
)
#Need arches to compile for. Compiles for 70 which requires CUDA9
nvcc_cmd
=
[
'nvcc'
,
'-Xcompiler'
,
'-fPIC'
,
'-gencode'
,
'arch=compute_52,code=sm_52'
,
gencodes
=
[
'-gencode'
,
'arch=compute_52,code=sm_52'
,
'-gencode'
,
'arch=compute_60,code=sm_60'
,
'-gencode'
,
'arch=compute_61,code=sm_61'
,
'-gencode'
,
'arch=compute_70,code=sm_70'
,
'-gencode'
,
'arch=compute_70,code=compute_70'
,
'--std=c++11'
,
'-O3'
,
]
'-gencode'
,
'arch=compute_61,code=sm_61'
,]
if
CUDA_VERSION
>
8
:
gencodes
+=
[
'-gencode'
,
'arch=compute_70,code=sm_70'
,
'-gencode'
,
'arch=compute_70,code=compute_70'
,]
#Need arches to compile for. Compiles for 70 which requires CUDA9
nvcc_cmd
=
[
NVCC
,
'-Xcompiler'
,
'-fPIC'
]
+
gencodes
+
[
'--std=c++11'
,
'-O3'
,
]
for
dir
in
include_dirs
:
nvcc_cmd
.
append
(
"-I"
+
dir
)
...
...
@@ -152,7 +191,7 @@ if 'clean' not in sys.argv:
print
(
"library_dirs: "
,
library_dirs
)
print
(
"libraries: "
,
main_libraries
)
print
()
CompileCudaFiles
()
CompileCudaFiles
(
NVCC
,
CUDA_VERSION
)
print
(
"Building CUDA extension."
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment