ffi.py 527 Bytes
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
from .._ext import ffi


def get_func(name, is_cuda, tensor=None):
    """Look up a compiled kernel named ``<backend><type>_<name>`` in ``ffi``.

    ``backend`` is ``'THCC'`` when *is_cuda* is true and ``'TH'`` otherwise.
    ``type`` is ``'Tensor'`` when *tensor* is None; otherwise it is the last
    dot-separated component of ``tensor.type()`` (for example
    ``'torch.cuda.FloatTensor'`` -> ``'FloatTensor'``, assuming PyTorch's
    type-string format).

    Any ``AttributeError`` raised by ``getattr`` for a missing symbol is
    propagated unchanged.
    """
    backend = 'THCC' if is_cuda else 'TH'
    if tensor is None:
        type_name = 'Tensor'
    else:
        # Keep only the class-like suffix of the dotted type string.
        type_name = tensor.type().split('.')[-1]
    symbol = '{}{}_{}'.format(backend, type_name, name)
    return getattr(ffi, symbol)


def graclus(self, row, col, weight=None):
    """Invoke the ``graclus`` kernel from the FFI extension.

    The backend (CPU ``TH`` vs. CUDA ``THCC``) is chosen from
    ``self.is_cuda``. The type-specific kernel is selected from the type of
    *weight*; the generic ``'Tensor'`` variant is used when *weight* is None.
    *weight* is forwarded to the kernel only when it is given.

    Nothing is returned. Presumably the kernel writes its result into
    ``self`` in place (TODO confirm against the C implementation).
    """
    func = get_func('graclus', self.is_cuda, weight)
    # Use a plain statement rather than a discarded conditional expression:
    # the two branches differ only in whether the weight argument is passed.
    if weight is None:
        func(self, row, col)
    else:
        func(self, row, col, weight)


def grid(self, pos, size, count):
    """Invoke the ``grid`` kernel from the FFI extension.

    The backend (CPU ``TH`` vs. CUDA ``THCC``) is chosen from
    ``self.is_cuda``, and the type-specific kernel is selected from the type
    of *pos*. All four tensors are passed straight through to the kernel.

    Nothing is returned. Presumably the output is written into ``self`` in
    place (TODO confirm against the C implementation).
    """
    get_func('grid', self.is_cuda, pos)(self, pos, size, count)