ffi.py 520 Bytes
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from .._ext import ffi


def get_func(name, is_cuda, tensor=None):
    """Look up a compiled routine on the ``ffi`` extension object by name.

    The attribute fetched from ``ffi`` is ``<backend><type>_<name>``, where
    ``<backend>`` is ``'THCC'`` when ``is_cuda`` is truthy and ``'TH'``
    otherwise, and ``<type>`` is the class name of ``tensor``, or the literal
    ``'Tensor'`` when ``tensor`` is None.

    Args:
        name: Suffix of the symbol, e.g. ``'grid'`` or ``'graclus'``.
        is_cuda: Selects the ``'THCC'`` (truthy) or ``'TH'`` (falsy) prefix.
        tensor: Optional object whose class name selects the typed variant
            (presumably a torch tensor such as ``FloatTensor`` — confirm
            against callers).

    Returns:
        The attribute of ``ffi`` with the constructed name.

    Raises:
        AttributeError: If ``ffi`` has no attribute with that name.
    """
    backend = 'THCC' if is_cuda else 'TH'
    if tensor is None:
        type_name = 'Tensor'
    else:
        type_name = type(tensor).__name__
    return getattr(ffi, '{}{}_{}'.format(backend, type_name, name))


def graclus(self, row, col, weight=None):
    """Dispatch to the compiled ``graclus`` routine.

    The FFI symbol is chosen by ``get_func``: the backend prefix comes from
    ``self.is_cuda``, and the type part comes from ``weight``'s class name,
    or the untyped ``'Tensor'`` variant when ``weight`` is None.

    Args:
        self: First argument forwarded to the routine. This presumably is the
            output tensor filled in place — TODO confirm against the C side.
            It must expose an ``is_cuda`` attribute.
        row: Forwarded unchanged to the routine.
        col: Forwarded unchanged to the routine.
        weight: Optional extra argument. When given, it is forwarded as the
            fourth argument and selects the typed routine.

    Returns:
        None. Any value returned by the FFI routine is discarded.
    """
    func = get_func('graclus', self.is_cuda, weight)
    # The weighted and unweighted routines have different arities, so only
    # pass ``weight`` when it was supplied.
    if weight is None:
        func(self, row, col)
    else:
        func(self, row, col, weight)


def grid(self, pos, size, count):
    """Dispatch to the compiled ``grid`` routine.

    ``get_func`` picks the FFI symbol: the backend prefix comes from
    ``self.is_cuda`` and the type part from ``pos``'s class name. The routine
    is then called as ``routine(self, pos, size, count)``.

    Args:
        self: First argument forwarded to the routine. This presumably is the
            output tensor written in place — TODO confirm against the C side.
            It must expose an ``is_cuda`` attribute.
        pos: Forwarded unchanged; its class name selects the typed routine.
        size: Forwarded unchanged.
        count: Forwarded unchanged.

    Returns:
        None. Any value returned by the FFI routine is discarded.
    """
    routine = get_func('grid', self.is_cuda, pos)
    routine(self, pos, size, count)
    return None