build.py 846 Bytes
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import subprocess

import torch
from torch.utils.ffi import create_extension

# Build configuration for the `torch_cluster._ext.ffi` extension.
#
# The lists below start with the CPU-only configuration and are extended in
# place when a CUDA device is available.
headers = []
sources = []
include_dirs = ['torch_cluster/src']
define_macros = []
extra_objects = []
with_cuda = False

if torch.cuda.is_available():
    # Compile kernel. `check_call` (rather than `call`) raises
    # `subprocess.CalledProcessError` when the script exits non-zero, so a
    # failed kernel compilation stops the build here instead of surfacing
    # later as a missing or stale object at link time.
    subprocess.check_call('./build.sh')

    headers += ['torch_cluster/src/cuda.h']
    sources += ['torch_cluster/src/cuda.c']
    include_dirs += ['torch_cluster/kernel']
    define_macros += [('WITH_CUDA', None)]
    # Presumably the object produced by `build.sh` above -- verify against
    # the script if the output path ever changes.
    extra_objects += ['torch_cluster/build/kernel.so']
    with_cuda = True

# The extension paths above are resolved relative to this file
# (`relative_to=__file__`). Note that './build.sh' is still resolved against
# the current working directory of the process.
ffi = create_extension(
    name='torch_cluster._ext.ffi',
    package=True,
    headers=headers,
    sources=sources,
    include_dirs=include_dirs,
    define_macros=define_macros,
    extra_objects=extra_objects,
    with_cuda=with_cuda,
    relative_to=__file__)

if __name__ == '__main__':
    # Only build when executed as a script; importing this module merely
    # exposes the configured `ffi` object.
    ffi.build()