Commit 83acda92 authored by Christian Sarofeen's avatar Christian Sarofeen
Browse files

Fix build for pytorch post 0.4

parent cc8f03c8
......@@ -106,7 +106,12 @@ cuda_files = find(curdir, lambda file: file.endswith(".cu"), True)
cuda_headers = find(curdir, lambda file: file.endswith(".cuh"), True)
headers = find(curdir, lambda file: file.endswith(".h"), True)
libaten = find(torch_dir, re.compile("libaten", re.IGNORECASE).search, False)
libaten = list(set(find(torch_dir, re.compile("libaten", re.IGNORECASE).search, True)))
libaten_names = [os.path.splitext(os.path.basename(entry))[0] for entry in libaten]
for i, entry in enumerate(libaten_names):
if entry[:3]=='lib':
libaten_names[i] = entry[3:]
aten_h = find(torch_dir, re.compile("aten.h", re.IGNORECASE).search, False)
include_dirs = [os.path.dirname(os.path.dirname(aten_h))]
......@@ -119,13 +124,13 @@ for file in cuda_headers+headers:
assert libaten, "Could not find PyTorch's libATen."
assert aten_h, "Could not find PyTorch's ATen header."
library_dirs.append(os.path.dirname(libaten))
library_dirs.append(os.path.dirname(libaten[0]))
#create some places to collect important things
object_files = []
extra_link_args=[]
main_libraries = []
main_libraries += ['cudart', 'ATen']
main_libraries += ['cudart',]+libaten_names
extra_compile_args = ["--std=c++11",]
#findcuda returns root dir of CUDA
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment