Commit 149b2f85 authored by rusty1s's avatar rusty1s
Browse files

gputest

parent d9be28c0
...@@ -8,23 +8,23 @@ from torch.utils.ffi import create_extension ...@@ -8,23 +8,23 @@ from torch.utils.ffi import create_extension
if osp.exists('build'): if osp.exists('build'):
shutil.rmtree('build') shutil.rmtree('build')
files = ['serial', 'grid'] files = ['Greedy', 'Grid']
headers = ['torch_cluster/src/{}_cpu.h'.format(f) for f in files] headers = ['aten/TH/TH{}.h'.format(f) for f in files]
sources = ['torch_cluster/src/{}_cpu.c'.format(f) for f in files] sources = ['aten/TH/TH{}.c'.format(f) for f in files]
include_dirs = ['torch_cluster/src', 'aten/TH'] include_dirs = ['aten/TH']
define_macros = [] define_macros = []
extra_objects = [] extra_objects = []
with_cuda = False with_cuda = False
if torch.cuda.is_available(): if torch.cuda.is_available():
subprocess.call(['./build.sh', osp.dirname(torch.__file__)]) subprocess.call(['./build_new.sh', osp.dirname(torch.__file__)])
headers += ['torch_cluster/src/{}_cuda.h'.format(f) for f in files] headers += ['aten/THCC/THCC{}.h'.format(f) for f in files]
sources += ['torch_cluster/src/{}_cuda.c'.format(f) for f in files] sources += ['aten/THCC/THCC{}.c'.format(f) for f in files]
include_dirs += ['torch_cluster/kernel'] include_dirs += ['aten/THC', 'aten/THCC']
define_macros += [('WITH_CUDA', None)] define_macros += [('WITH_CUDA', None)]
extra_objects += ['torch_cluster/build/{}.so'.format(f) for f in files] extra_objects += ['aten/build/THC{}.so'.format(f) for f in files]
with_cuda = True with_cuda = True
ffi = create_extension( ffi = create_extension(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment