ffi.py 1.58 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
from .._ext import ffi

# B-spline degree -> name fragment used in the compiled kernel symbol names.
implemented_degrees = dict(enumerate(('linear', 'quadratic', 'cubic'), start=1))


rusty1s's avatar
cleaner  
rusty1s committed
6
7
def get_func(name, tensor):
    """Look up the compiled kernel ``name`` specialised for ``tensor``'s type.

    The attribute fetched from ``ffi`` is ``<backend><TypeName>_<name>``:
    ``<backend>`` is ``'THCC'`` when ``tensor.is_cuda`` is true and ``'TH'``
    otherwise, and ``<TypeName>`` is the part of ``tensor.type()`` after its
    last ``'.'`` (the whole string if it contains no dot).

    Args:
        name: Kernel name suffix, e.g. ``'weightingForward'``.
        tensor: Object whose ``is_cuda`` flag and ``type()`` string select
            the specialisation.

    Returns:
        The attribute of ``ffi`` with that symbol name.

    Raises:
        AttributeError: if ``ffi`` has no attribute with that name.
    """
    backend = 'THCC' if tensor.is_cuda else 'TH'
    # Keep only the trailing class part of the dotted type string.
    type_name = tensor.type().rpartition('.')[2]
    symbol = '{}{}_{}'.format(backend, type_name, name)
    return getattr(ffi, symbol)


class _UnsupportedDegreeError(ValueError, AssertionError):
    """Signal that no compiled kernel exists for a requested B-spline degree.

    Subclasses ``AssertionError`` so callers written against the previous
    ``assert``-based check still catch it, and ``ValueError`` because a bad
    argument value is what it reports.
    """


def get_degree_str(degree):
    """Map a B-spline degree to the name fragment used in kernel symbols.

    Args:
        degree: Key into ``implemented_degrees`` (currently 1, 2 or 3).

    Returns:
        The matching name from ``implemented_degrees``, e.g. ``'linear'``.

    Raises:
        ValueError: (also an ``AssertionError``) if ``degree`` has no entry
            in ``implemented_degrees``.
    """
    degree_name = implemented_degrees.get(degree)
    # Explicit raise instead of ``assert``: asserts are stripped under
    # ``python -O``, which let ``None`` through and made callers build kernel
    # names like 'NoneBasisForward' that only failed later and obscurely.
    if degree_name is None:
        raise _UnsupportedDegreeError(
            'No implementation found for specified B-spline degree '
            '(got {!r}, supported: {})'.format(
                degree, sorted(implemented_degrees)))
    return degree_name


rusty1s's avatar
rusty1s committed
19
def fw_basis(degree, basis, weight_index, pseudo, kernel_size, is_open_spline):
    """Invoke the forward B-spline basis kernel for ``degree``.

    The kernel is ``<degree name>BasisForward`` (e.g. ``linearBasisForward``),
    specialised on the type of ``basis``. It is called with
    ``(basis, weight_index, pseudo, kernel_size, is_open_spline)``.
    Returns ``None``. Presumably the kernel writes its results into ``basis``
    and ``weight_index`` in place (NOTE(review): confirm against the C code).
    """
    kernel_args = (basis, weight_index, pseudo, kernel_size, is_open_spline)
    kernel_name = '{}BasisForward'.format(get_degree_str(degree))
    get_func(kernel_name, basis)(*kernel_args)
rusty1s's avatar
rusty1s committed
23
24


rusty1s's avatar
rusty1s committed
25
def bw_basis(degree, self, grad_basis, pseudo, kernel_size, is_open_spline):
    """Invoke the backward B-spline basis kernel for ``degree``.

    The kernel is ``<degree name>BasisBackward``, specialised on the type of
    ``self``. It is called with
    ``(self, grad_basis, pseudo, kernel_size, is_open_spline)``.
    Returns ``None``. ``self`` is presumably the output buffer the kernel
    fills in place (NOTE(review): confirm against the C code).
    """
    kernel_args = (self, grad_basis, pseudo, kernel_size, is_open_spline)
    kernel_name = '{}BasisBackward'.format(get_degree_str(degree))
    get_func(kernel_name, self)(*kernel_args)
rusty1s's avatar
rusty1s committed
29
30


rusty1s's avatar
rusty1s committed
31
def fw_weighting(self, src, weight, basis, weight_index):
    """Invoke the ``weightingForward`` kernel specialised on ``self``'s type.

    The kernel is called with ``(self, src, weight, basis, weight_index)``.
    Returns ``None``. ``self`` is presumably the output tensor written in
    place (NOTE(review): confirm against the C code).
    """
    operands = (self, src, weight, basis, weight_index)
    get_func('weightingForward', self)(*operands)


rusty1s's avatar
rusty1s committed
36
def bw_weighting_src(self, grad_output, weight, basis, weight_index):
    """Invoke ``weightingBackwardSrc`` specialised on ``self``'s type.

    The kernel is called with
    ``(self, grad_output, weight, basis, weight_index)``. Returns ``None``.
    ``self`` is presumably the gradient buffer for ``src``, filled in place
    (NOTE(review): confirm against the C code).
    """
    operands = (self, grad_output, weight, basis, weight_index)
    get_func('weightingBackwardSrc', self)(*operands)


rusty1s's avatar
rusty1s committed
41
def bw_weighting_weight(self, grad_output, src, basis, weight_index):
    """Invoke ``weightingBackwardWeight`` specialised on ``self``'s type.

    The kernel is called with
    ``(self, grad_output, src, basis, weight_index)``. Returns ``None``.
    ``self`` is presumably the gradient buffer for ``weight``, filled in
    place (NOTE(review): confirm against the C code).
    """
    operands = (self, grad_output, src, basis, weight_index)
    get_func('weightingBackwardWeight', self)(*operands)


rusty1s's avatar
rusty1s committed
46
def bw_weighting_basis(self, grad_output, src, weight, weight_index):
    """Invoke ``weightingBackwardBasis`` specialised on ``self``'s type.

    The kernel is called with
    ``(self, grad_output, src, weight, weight_index)``. Returns ``None``.
    ``self`` is presumably the gradient buffer for ``basis``, filled in
    place (NOTE(review): confirm against the C code).
    """
    operands = (self, grad_output, src, weight, weight_index)
    get_func('weightingBackwardBasis', self)(*operands)