build.py 1.25 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
import os.path as osp
rusty1s's avatar
rusty1s committed
2
import shutil
rusty1s's avatar
rusty1s committed
3
4
5
6
7
import subprocess

import torch
from torch.utils.ffi import create_extension

rusty1s's avatar
rusty1s committed
8
# Delete any existing ``build`` path (resolved against the current working
# directory) before the extension is configured.
_BUILD_DIR = 'build'
if osp.exists(_BUILD_DIR):
    shutil.rmtree(_BUILD_DIR)


rusty1s's avatar
rusty1s committed
11
12
13
# Base names of the operators; each one contributes a ``<name>_cpu.h`` /
# ``<name>_cpu.c`` pair under ``torch_cluster/src``.
files = ['serial', 'grid']

# CPU wrapper headers followed by the TH headers they are listed with.
headers = (
    ['torch_cluster/src/{}_cpu.h'.format(stem) for stem in files] +
    ['aten/TH/THGreedy.h', 'aten/TH/THGrid.h']
)

# Matching C sources, in the same order as the headers.
sources = (
    ['torch_cluster/src/{}_cpu.c'.format(stem) for stem in files] +
    ['aten/TH/THGreedy.c', 'aten/TH/THGrid.c']
)

# Directories added to the include search path.
include_dirs = ['torch_cluster/src', 'aten/TH']
rusty1s's avatar
rusty1s committed
18
19
20
21
22
# Defaults for a CPU-only build: CUDA disabled, no extra preprocessor
# macros and no extra objects to link. The CUDA branch below extends these.
with_cuda = False
define_macros, extra_objects = [], []

if torch.cuda.is_available():
    # Run the helper script with the directory of the installed ``torch``
    # package as its only argument (presumably it compiles the CUDA kernels
    # into ``torch_cluster/build`` -- confirm against build.sh).
    # ``check_call`` raises ``subprocess.CalledProcessError`` on a non-zero
    # exit status, so a failed helper build stops here instead of being
    # silently ignored and surfacing later as missing ``.so`` objects.
    subprocess.check_call(['./build.sh', osp.dirname(torch.__file__)])

    # Add the CUDA wrapper headers/sources next to the CPU ones.
    headers += ['torch_cluster/src/{}_cuda.h'.format(f) for f in files]
    sources += ['torch_cluster/src/{}_cuda.c'.format(f) for f in files]

    include_dirs += ['torch_cluster/kernel']
    # Defines ``WITH_CUDA`` (without a value) for the C compilation.
    define_macros += [('WITH_CUDA', None)]
    # One shared object per operator is passed to the extension as an
    # extra object.
    extra_objects += ['torch_cluster/build/{}.so'.format(f) for f in files]

    with_cuda = True

# All options for the FFI extension, gathered in one place. Relative paths
# in the lists above are passed together with ``relative_to=__file__``.
_ext_options = dict(
    name='torch_cluster._ext.ffi',
    package=True,
    headers=headers,
    sources=sources,
    include_dirs=include_dirs,
    define_macros=define_macros,
    extra_objects=extra_objects,
    with_cuda=with_cuda,
    relative_to=__file__,
)

# Module-level extension object; created on import, built only on request.
ffi = create_extension(**_ext_options)

if __name__ == '__main__':
    # Executed as a script: actually build the extension.
    ffi.build()