Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
5709cfb5
Commit
5709cfb5
authored
May 01, 2018
by
Christian Sarofeen
Browse files
Try to improve robustness of finding cuda in build. Try to support building with CUDA 8.
parent
1cea1005
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
69 additions
and
30 deletions
+69
-30
setup.py
setup.py
+69
-30
No files found.
setup.py
View file @
5709cfb5
...
@@ -17,10 +17,16 @@ import ctypes.util
...
@@ -17,10 +17,16 @@ import ctypes.util
import
torch
import
torch
#Takes a path to walk
#A function to decide if to keep
#collection if we want a list of all occurances
def
find
(
path
,
regex_func
,
collect
=
False
):
def
find
(
path
,
regex_func
,
collect
=
False
):
"""
Recursively searches through a directory with regex_func and
either collects all instances or returns the first instance.
Args:
path: Directory to search through
regex_function: A function to run on each file to decide if it should be returned/collected
collect (False) : If True will collect all instances of matching, else will return first instance only
"""
collection
=
[]
if
collect
else
None
collection
=
[]
if
collect
else
None
for
root
,
dirs
,
files
in
os
.
walk
(
path
):
for
root
,
dirs
,
files
in
os
.
walk
(
path
):
for
file
in
files
:
for
file
in
files
:
...
@@ -31,25 +37,55 @@ def find(path, regex_func, collect=False):
...
@@ -31,25 +37,55 @@ def find(path, regex_func, collect=False):
return
os
.
path
.
join
(
root
,
file
)
return
os
.
path
.
join
(
root
,
file
)
return
list
(
set
(
collection
))
return
list
(
set
(
collection
))
def
findcuda
():
def
findcuda
():
"""
Based on PyTorch build process. Will look for nvcc for compilation.
Either will set cuda home by enviornment variable CUDA_HOME or will search
for nvcc. Returns NVCC executable, cuda major version and cuda home directory.
"""
cuda_path
=
None
CUDA_HOME
=
None
CUDA_HOME
=
os
.
getenv
(
'CUDA_HOME'
,
'/usr/local/cuda'
)
CUDA_HOME
=
os
.
getenv
(
'CUDA_HOME'
,
'/usr/local/cuda'
)
if
not
os
.
path
.
exists
(
CUDA_HOME
):
if
not
os
.
path
.
exists
(
CUDA_HOME
):
# We use nvcc path on Linux and cudart path on macOS
# We use nvcc path on Linux and cudart path on macOS
osname
=
platform
.
system
()
if
osname
==
'Linux'
:
cuda_path
=
find_nvcc
()
else
:
cudart_path
=
ctypes
.
util
.
find_library
(
'cudart'
)
cudart_path
=
ctypes
.
util
.
find_library
(
'cudart'
)
if
cudart_path
is
not
None
:
if
cudart_path
is
not
None
:
cuda_path
=
os
.
path
.
dirname
(
cudart_path
)
cuda_path
=
os
.
path
.
dirname
(
cudart_path
)
else
:
cuda_path
=
None
if
cuda_path
is
not
None
:
if
cuda_path
is
not
None
:
CUDA_HOME
=
os
.
path
.
dirname
(
cuda_path
)
CUDA_HOME
=
os
.
path
.
dirname
(
cuda_path
)
if
not
cuda_path
and
not
CUDA_HOME
:
nvcc_path
=
find
(
'/usr/local/'
,
re
.
compile
(
"nvcc"
).
search
,
False
)
if
nvcc_path
:
CUDA_HOME
=
os
.
path
.
dirname
(
nvcc_path
)
if
CUDA_HOME
:
os
.
path
.
dirname
(
CUDA_HOME
)
if
(
not
os
.
path
.
exists
(
CUDA_HOME
+
os
.
sep
+
"lib64"
)
or
not
os
.
path
.
exists
(
CUDA_HOME
+
os
.
sep
+
"include"
)
):
raise
RuntimeError
(
"Error: found NVCC at "
,
nvcc_path
,
" but could not locate CUDA libraries"
+
" or include directories."
)
raise
RuntimeError
(
"Error: Could not find cuda on this system."
+
" Please set your CUDA_HOME enviornment variable to the CUDA base directory."
)
NVCC
=
find
(
CUDA_HOME
,
re
.
compile
(
'nvcc'
).
search
)
CUDA_LIB
=
find
(
CUDA_HOME
,
re
.
compile
(
'libcudart.so.*.*.*'
).
search
)
if
CUDA_LIB
:
try
:
CUDA_VERSION
=
int
(
CUDA_LIB
.
split
(
'.'
)[
2
])
except
(
ValueError
,
TypeError
):
CUDA_VERSION
=
9
else
:
else
:
CUDA_HOME
=
None
CUDA_VERSION
=
9
WITH_CUDA
=
CUDA_HOME
is
not
None
return
CUDA_HOME
if
CUDA_VERSION
<
8
:
raise
RuntimeError
(
"Error: APEx requires CUDA 8 or newer"
)
return
NVCC
,
CUDA_VERSION
,
CUDA_HOME
#Get some important paths
#Get some important paths
curdir
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
inspect
.
stack
()[
0
][
1
]))
curdir
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
inspect
.
stack
()[
0
][
1
]))
...
@@ -87,7 +123,7 @@ extra_compile_args = ["--std=c++11",]
...
@@ -87,7 +123,7 @@ extra_compile_args = ["--std=c++11",]
#findcuda returns root dir of CUDA
#findcuda returns root dir of CUDA
#include cuda/include and cuda/lib64 for python module build.
#include cuda/include and cuda/lib64 for python module build.
CUDA_HOME
=
findcuda
()
NVCC
,
CUDA_VERSION
,
CUDA_HOME
=
findcuda
()
library_dirs
.
append
(
os
.
path
.
join
(
CUDA_HOME
,
"lib64"
))
library_dirs
.
append
(
os
.
path
.
join
(
CUDA_HOME
,
"lib64"
))
include_dirs
.
append
(
os
.
path
.
join
(
CUDA_HOME
,
'include'
))
include_dirs
.
append
(
os
.
path
.
join
(
CUDA_HOME
,
'include'
))
...
@@ -107,19 +143,22 @@ class RMBuild(clean):
...
@@ -107,19 +143,22 @@ class RMBuild(clean):
shutil
.
rmtree
(
eggdir
)
shutil
.
rmtree
(
eggdir
)
clean
.
run
(
self
)
clean
.
run
(
self
)
def
CompileCudaFiles
():
def
CompileCudaFiles
(
NVCC
,
CUDA_VERSION
):
print
()
print
()
print
(
"Compiling cuda modules with nvcc:"
)
print
(
"Compiling cuda modules with nvcc:"
)
gencodes
=
[
'-gencode'
,
'arch=compute_52,code=sm_52'
,
'-gencode'
,
'arch=compute_60,code=sm_60'
,
'-gencode'
,
'arch=compute_61,code=sm_61'
,]
if
CUDA_VERSION
>
8
:
gencodes
+=
[
'-gencode'
,
'arch=compute_70,code=sm_70'
,
'-gencode'
,
'arch=compute_70,code=compute_70'
,]
#Need arches to compile for. Compiles for 70 which requires CUDA9
#Need arches to compile for. Compiles for 70 which requires CUDA9
nvcc_cmd
=
[
'nvcc'
,
nvcc_cmd
=
[
NVCC
,
'-Xcompiler'
,
'-Xcompiler'
,
'-fPIC'
,
'-fPIC'
'-gencode'
,
'arch=compute_52,code=sm_52'
,
]
+
gencodes
+
[
'-gencode'
,
'arch=compute_60,code=sm_60'
,
'-gencode'
,
'arch=compute_61,code=sm_61'
,
'-gencode'
,
'arch=compute_70,code=sm_70'
,
'-gencode'
,
'arch=compute_70,code=compute_70'
,
'--std=c++11'
,
'--std=c++11'
,
'-O3'
,
'-O3'
,
]
]
...
@@ -152,7 +191,7 @@ if 'clean' not in sys.argv:
...
@@ -152,7 +191,7 @@ if 'clean' not in sys.argv:
print
(
"library_dirs: "
,
library_dirs
)
print
(
"library_dirs: "
,
library_dirs
)
print
(
"libraries: "
,
main_libraries
)
print
(
"libraries: "
,
main_libraries
)
print
()
print
()
CompileCudaFiles
()
CompileCudaFiles
(
NVCC
,
CUDA_VERSION
)
print
(
"Building CUDA extension."
)
print
(
"Building CUDA extension."
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment