ffi.py 962 Bytes
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
from ..._ext import ffi


rusty1s's avatar
rusty1s committed
4
def _get_func(name, tensor):
    """Look up the untyped FFI function for ``name``.

    The attribute fetched from ``ffi`` is ``cluster_<name>``, with a
    ``_cuda`` suffix appended when ``tensor.is_cuda`` is true.

    Args:
        name: Operation name inserted into the symbol name.
        tensor: Object whose ``is_cuda`` flag selects the variant.

    Returns:
        The attribute of ``ffi`` with the resolved name.

    Raises:
        AttributeError: If ``ffi`` has no attribute with that name.
    """
    if tensor.is_cuda:
        suffix = '_cuda'
    else:
        suffix = ''
    symbol = 'cluster_{}{}'.format(name, suffix)
    return getattr(ffi, symbol)


rusty1s's avatar
rusty1s committed
9
def _get_typed_func(name, tensor):
    """Look up the dtype-specific FFI function for ``name``.

    The attribute fetched from ``ffi`` is ``cluster_<name>_<Type>`` for
    non-CUDA tensors and ``cluster_<name>_cuda_<Type>`` for CUDA tensors,
    where ``<Type>`` is the tensor type name without the ``Tensor``
    suffix (e.g. ``Float`` for ``torch.FloatTensor``).

    Args:
        name: Operation name inserted into the symbol name.
        tensor: Tensor whose type string and ``is_cuda`` flag select the
            variant.

    Returns:
        The attribute of ``ffi`` with the resolved name.

    Raises:
        AttributeError: If ``ffi`` has no attribute with that name.
    """
    # Use the type string rather than the Python class name: the class of
    # a unified ``torch.Tensor`` is just ``Tensor`` and carries no dtype,
    # while ``tensor.type()`` still reports e.g. ``torch.cuda.FloatTensor``.
    # For legacy per-dtype tensor classes both approaches give the same
    # token.
    type_string = tensor.type()
    typename = type_string.rsplit('.', 1)[-1].replace('Tensor', '')
    cuda = 'cuda_' if tensor.is_cuda else ''
    return getattr(ffi, 'cluster_{}_{}{}'.format(name, cuda, typename))
rusty1s's avatar
rusty1s committed
13
14


rusty1s's avatar
bugfix  
rusty1s committed
15
16
def ffi_serial(row, col, degree, weight=None):
    """Run the ``serial`` FFI routine and return its output tensor.

    A tensor of the same type as ``row`` with ``degree.size(0)`` entries
    is allocated, filled with ``-1`` and handed to the FFI function as its
    first argument; what the function writes into it is not visible here.

    Args:
        row: Tensor passed through to the FFI function; also determines
            the type of the output and, when ``weight`` is ``None``,
            whether the CUDA variant is used.
        col: Tensor passed through to the FFI function.
        degree: Tensor passed through to the FFI function; its first
            dimension sets the output length.
        weight: Optional tensor. When given, the dtype-specific variant
            matching ``weight`` is called and ``weight`` is passed as an
            extra trailing argument.

    Returns:
        The output tensor after the FFI call.
    """
    length = degree.size(0)
    output = row.new(length).fill_(-1)

    if weight is not None:
        # Weighted variant: dispatch on the type of ``weight``.
        kernel = _get_typed_func('serial', weight)
        kernel(output, row, col, degree, weight)
    else:
        # Unweighted variant: only CPU/CUDA dispatch on ``row``.
        kernel = _get_func('serial', row)
        kernel(output, row, col, degree)

    return output


rusty1s's avatar
rusty1s committed
27
28
29
def ffi_grid(position, size, count):
    """Run the ``grid`` FFI routine and return its output tensor.

    Args:
        position: Tensor passed to the FFI function; its type selects the
            dtype-specific variant and its first dimension sets the number
            of output entries.
        size: Passed through to the FFI function unchanged.
        count: Tensor passed to the FFI function; its product is supplied
            as the first argument and its type is used for the output.

    Returns:
        The output tensor with its trailing dimension squeezed away.
    """
    total = count.prod()
    rows = position.size(0)
    # Allocated (uninitialised) with a trailing singleton dimension;
    # presumably the FFI function fills every entry -- TODO confirm.
    buffer = count.new(rows, 1)
    kernel = _get_typed_func('grid', position)
    kernel(total, buffer, position, size, count)
    return buffer.squeeze(-1)