ffi.py 837 Bytes
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
from ..._ext import ffi


rusty1s's avatar
rusty1s committed
4
def _get_func(name, tensor):
    """Return the untyped FFI entry point ``cluster_<name>`` for *tensor*.

    If ``tensor.is_cuda`` is truthy, the ``_cuda`` variant
    (``cluster_<name>_cuda``) is looked up instead.

    Raises ``AttributeError`` (from ``getattr``) when ``ffi`` has no
    function with the constructed name.
    """
    on_gpu = tensor.is_cuda
    symbol = 'cluster_{}'.format(name)
    if on_gpu:
        symbol += '_cuda'
    return getattr(ffi, symbol)


rusty1s's avatar
rusty1s committed
9
def _get_typed_func(name, tensor):
    """Return the type-specialised FFI entry point for *tensor*.

    The looked-up name is ``cluster_<name>_<typename>``, or
    ``cluster_<name>_cuda_<typename>`` when ``tensor.is_cuda`` is truthy.
    Here ``<typename>`` is the class name of *tensor* with every
    occurrence of ``'Tensor'`` removed; for example, a class named
    ``FloatTensor`` yields ``Float``.

    Raises ``AttributeError`` (from ``getattr``) when ``ffi`` has no
    function with the constructed name.
    """
    prefix = 'cluster_{}_'.format(name)
    if tensor.is_cuda:
        prefix += 'cuda_'
    scalar_name = type(tensor).__name__.replace('Tensor', '')
    return getattr(ffi, prefix + scalar_name)
rusty1s's avatar
rusty1s committed
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29


def ffi_serial(output, row, col, degree, weight=None):
    """Invoke the ``serial`` FFI kernel and return *output*.

    Without *weight*, the untyped kernel is selected from the device of
    *row* and called as ``func(output, row, col, degree)``. With *weight*,
    the type-specialised kernel is selected from the type and device of
    *weight* and called with *weight* appended as a fifth argument.

    The same *output* object that was passed in is returned. Presumably
    the native kernel fills *output* in place; confirm against the C
    sources.
    """
    if weight is not None:
        # Typed kernel: dispatch on the weight tensor's class and device.
        kernel = _get_typed_func('serial', weight)
        kernel(output, row, col, degree, weight)
    else:
        # Untyped kernel: only the device (CPU vs. CUDA) of `row` matters.
        kernel = _get_func('serial', row)
        kernel(output, row, col, degree)
    return output


def ffi_grid(C, output, position, size, count):
    """Invoke the type-specialised ``grid`` FFI kernel and return *output*.

    The kernel is chosen from the class and device of *position* and is
    called as ``kernel(C, output, position, size, count)``. The same
    *output* object that was passed in is returned. Presumably it is
    filled in place by the native code; confirm against the C sources.
    """
    _get_typed_func('grid', position)(C, output, position, size, count)
    return output