"""Build script for the ``torch_scatter._ext.ffi`` cffi extension.

The CPU sources are always compiled. When ``torch.cuda.is_available()`` is
true on the machine running the build, the CUDA wrapper sources are compiled
as well and the ``WITH_CUDA`` macro is defined.

Run ``python build.py`` to build the extension. The module-level ``ffi``
object is only built under the ``__main__`` guard, so another script may also
import it and call ``ffi.build()`` itself.
"""
import torch
from torch.utils.ffi import create_extension

# NOTE(review): torch.utils.ffi was deprecated in PyTorch 0.4 and removed in
# 1.0 in favour of torch.utils.cpp_extension, so this script presumably
# targets PyTorch <= 0.4 -- confirm before upgrading PyTorch.

# Source paths below are relative; ``relative_to=__file__`` is passed to
# create_extension, presumably so they resolve against this file's directory.
headers = ['torch_scatter/src/cpu.h']
sources = ['torch_scatter/src/cpu.c']
includes = ['torch_scatter/src']
defines = []
extra_objects = []  # No prebuilt objects are linked at present.
with_cuda = False

# CUDA support is decided by the machine running the build (no explicit
# flag): if a CUDA device is visible, add the CUDA wrappers and macro.
if torch.cuda.is_available():
    headers += ['torch_scatter/src/cuda.h']
    sources += ['torch_scatter/src/cuda.c']
    defines += [('WITH_CUDA', None)]
    with_cuda = True

ffi = create_extension(
    name='torch_scatter._ext.ffi',
    package=True,
    verbose=True,
    headers=headers,
    sources=sources,
    include_dirs=includes,
    define_macros=defines,
    extra_objects=extra_objects,
    with_cuda=with_cuda,
    relative_to=__file__,
)

if __name__ == '__main__':
    ffi.build()