ffi.py 876 Bytes
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
from ..._ext import ffi


rusty1s's avatar
rusty1s committed
4
def _get_func(name, tensor):
    """Return the untyped ``cluster_<name>`` function from the FFI module.

    The CUDA variant ``cluster_<name>_cuda`` is chosen when
    ``tensor.is_cuda`` is truthy. Otherwise the plain ``cluster_<name>``
    symbol is used.

    Args:
        name: Base kernel name, e.g. ``'serial'``.
        tensor: Tensor whose ``is_cuda`` flag selects the device variant.

    Returns:
        The attribute of ``ffi`` with the constructed name.

    Raises:
        AttributeError: if ``ffi`` has no attribute with that name.
    """
    if tensor.is_cuda:
        device_suffix = '_cuda'
    else:
        device_suffix = ''
    symbol = 'cluster_{name}{suffix}'.format(name=name, suffix=device_suffix)
    return getattr(ffi, symbol)


rusty1s's avatar
rusty1s committed
9
def _get_typed_func(name, tensor):
    """Return the type-specialised ``cluster_<name>_...`` FFI function.

    The symbol is ``cluster_<name>_<type>`` on CPU and
    ``cluster_<name>_cuda_<type>`` when ``tensor.is_cuda`` is truthy.
    ``<type>`` is the tensor's class name with every ``'Tensor'``
    substring removed.

    Args:
        name: Base kernel name, e.g. ``'serial'`` or ``'grid'``.
        tensor: Tensor whose class name and ``is_cuda`` flag select the
            variant.

    Returns:
        The attribute of ``ffi`` with the constructed name.

    Raises:
        AttributeError: if ``ffi`` has no attribute with that name.
    """
    # A class named e.g. ``FloatTensor`` contributes ``Float`` to the symbol.
    type_part = type(tensor).__name__.replace('Tensor', '')
    device_part = 'cuda_' if tensor.is_cuda else ''
    symbol = 'cluster_{name}_{device}{dtype}'.format(
        name=name, device=device_part, dtype=type_part)
    return getattr(ffi, symbol)
rusty1s's avatar
rusty1s committed
13
14


rusty1s's avatar
bugfix  
rusty1s committed
15
16
def ffi_serial(row, col, degree, weight=None):
    """Run the serial clustering FFI kernel and return its output tensor.

    The output buffer is allocated with ``row.new(degree.size(0))`` and
    pre-filled with ``-1``. The ``-1`` presumably marks "not yet assigned";
    confirm against the C kernel.

    With ``weight`` given, the type-specialised kernel selected from
    ``weight`` is called as ``func(output, row, col, degree, weight)``.
    Otherwise the untyped kernel selected from ``row`` is called as
    ``func(output, row, col, degree)``.

    Args:
        row: Tensor passed to the kernel. In the unweighted case it also
            selects the kernel and is the template for ``output``.
        col: Tensor passed to the kernel.
        degree: Tensor whose first dimension sets the length of ``output``.
        weight: Optional tensor. When not ``None`` it selects the typed
            kernel and is passed as the last argument.

    Returns:
        The ``output`` tensor after the kernel call.
    """
    output = row.new(degree.size(0)).fill_(-1)
    if weight is not None:
        kernel = _get_typed_func('serial', weight)
        kernel(output, row, col, degree, weight)
    else:
        kernel = _get_func('serial', row)
        kernel(output, row, col, degree)
    return output


def ffi_grid(C, output, position, size, count):
    """Call the type-specialised grid clustering FFI kernel.

    The kernel is chosen from ``position`` (its class name and ``is_cuda``
    flag). It is invoked with the arguments in their original order:
    ``(C, output, position, size, count)``. The kernel presumably writes
    its result into ``output`` in place; confirm against the C side.

    Returns:
        The same ``output`` object that was passed in.
    """
    _get_typed_func('grid', position)(C, output, position, size, count)
    return output