ffi.py 1.77 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from .._ext import ffi

# Supported B-spline degrees, mapped to the name fragment that
# get_degree_str() hands to the basis_* wrappers, which splice it into
# the compiled kernel symbol names (e.g. 'linearBasisForward').
implemented_degrees = {
    1: 'linear',
    2: 'quadratic',
    3: 'cubic',
}


def get_func(name, is_cuda, tensor=None):
    """Return the compiled kernel ``name`` from the ``ffi`` extension.

    The looked-up symbol is ``<backend><type_name>_<name>`` where

    * ``<backend>`` is ``'THCC'`` if ``is_cuda`` is truthy, else ``'TH'``;
    * ``<type_name>`` is ``'Tensor'`` when ``tensor`` is ``None``, otherwise
      the class name of ``tensor`` (``type(tensor).__name__``).

    For example, with ``is_cuda=False`` and a tensor whose class is named
    ``FloatTensor``, ``get_func('weightingForward', False, t)`` resolves
    ``ffi.THFloatTensor_weightingForward``.

    Args:
        name: Kernel suffix appended after the underscore.
        is_cuda: Selects the ``'THCC'`` (true) or ``'TH'`` (false) prefix.
        tensor: Optional object whose class name selects the typed variant.

    Returns:
        The attribute of ``ffi`` with the composed name.

    Raises:
        AttributeError: If ``ffi`` exposes no symbol with that name.
    """
    if is_cuda:
        backend = 'THCC'
    else:
        backend = 'TH'

    if tensor is None:
        type_name = 'Tensor'
    else:
        type_name = type(tensor).__name__

    symbol = '{}{}_{}'.format(backend, type_name, name)
    return getattr(ffi, symbol)


def get_degree_str(degree):
    """Translate a B-spline degree into its kernel-name fragment.

    Looks ``degree`` up in the module-level ``implemented_degrees`` table,
    e.g. ``1 -> 'linear'``. The returned string is spliced into the compiled
    kernel symbol names by the basis_* wrappers.

    Args:
        degree: The B-spline degree; must be a key of ``implemented_degrees``.

    Returns:
        The name fragment registered for ``degree``.

    Raises:
        AssertionError: If ``degree`` has no registered implementation. This
            is raised explicitly rather than via an ``assert`` statement, so
            the check still runs under ``python -O``. The exception type is
            kept as ``AssertionError`` for compatibility with existing callers.
    """
    degree_str = implemented_degrees.get(degree)
    if degree_str is None:
        # A bare `assert` would vanish under -O. Callers would then get None
        # and build a nonsense symbol like 'NoneBasisForward', failing later
        # with a confusing AttributeError.
        raise AssertionError(
            'No implementation found for specified B-spline degree')
    return degree_str


def basis_forward(degree, basis, weight_index, pseudo, kernel_size,
                  is_open_spline):
    """Invoke the forward B-spline basis kernel for ``degree``.

    The kernel symbol is ``<degree_str>BasisForward``, where ``<degree_str>``
    comes from ``get_degree_str(degree)``. ``get_func`` resolves the typed
    variant using ``basis.is_cuda`` and the class of ``basis``. The kernel
    receives ``(basis, weight_index, pseudo, kernel_size, is_open_spline)``
    in that order.

    NOTE(review): ``basis`` and ``weight_index`` are presumably output
    buffers the kernel fills in place. The wrapper returns nothing. TODO
    confirm against the C kernel.

    Returns:
        None. Any value returned by the kernel is discarded.
    """
    kernel = get_func('{}BasisForward'.format(get_degree_str(degree)),
                      basis.is_cuda, basis)
    kernel(basis, weight_index, pseudo, kernel_size, is_open_spline)
rusty1s's avatar
rusty1s committed
24
25
26
27
28
29
30


def basis_backward(degree, self, grad_basis, pseudo, kernel_size,
                   is_open_spline):
    """Invoke the backward B-spline basis kernel for ``degree``.

    The kernel symbol is ``<degree_str>BasisBackward``, where
    ``<degree_str>`` comes from ``get_degree_str(degree)``. ``get_func``
    resolves the typed variant using ``self.is_cuda`` and the class of
    ``self``. The kernel receives
    ``(self, grad_basis, pseudo, kernel_size, is_open_spline)``.

    NOTE(review): ``self`` is presumably the output tensor that receives the
    gradient. TODO confirm against the C kernel.

    Returns:
        None. Any value returned by the kernel is discarded.
    """
    degree_str = get_degree_str(degree)
    kernel_name = '{}BasisBackward'.format(degree_str)
    kernel = get_func(kernel_name, self.is_cuda, self)
    kernel(self, grad_basis, pseudo, kernel_size, is_open_spline)
rusty1s's avatar
rusty1s committed
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50


def weighting_forward(self, src, weight, basis, weight_index):
    """Call the ``weightingForward`` kernel matching ``self``'s type.

    The kernel receives ``(self, src, weight, basis, weight_index)``.
    NOTE(review): ``self`` is presumably the output buffer written in place.
    TODO confirm against the C kernel.

    Returns:
        None. The kernel's return value is discarded.
    """
    kernel = get_func('weightingForward', self.is_cuda, self)
    kernel(self, src, weight, basis, weight_index)


def weighting_backward_src(self, grad_output, weight, basis, weight_index):
    """Call the ``weightingBackwardSrc`` kernel matching ``self``'s type.

    The kernel receives ``(self, grad_output, weight, basis, weight_index)``.
    NOTE(review): ``self`` presumably receives the gradient w.r.t. ``src``.
    TODO confirm against the C kernel.

    Returns:
        None. The kernel's return value is discarded.
    """
    get_func('weightingBackwardSrc', self.is_cuda, self)(
        self, grad_output, weight, basis, weight_index)


def weighting_backward_weight(self, grad_output, src, basis, weight_index):
    """Call the ``weightingBackwardWeight`` kernel matching ``self``'s type.

    The kernel receives ``(self, grad_output, src, basis, weight_index)``.
    NOTE(review): ``self`` presumably receives the gradient w.r.t.
    ``weight``. TODO confirm against the C kernel.

    Returns:
        None. The kernel's return value is discarded.
    """
    kernel_name = 'weightingBackwardWeight'
    kernel = get_func(kernel_name, self.is_cuda, self)
    kernel(self, grad_output, src, basis, weight_index)


def weighting_backward_basis(self, grad_output, src, weight, weight_index):
    """Call the ``weightingBackwardBasis`` kernel matching ``self``'s type.

    The kernel receives ``(self, grad_output, src, weight, weight_index)``.
    NOTE(review): ``self`` presumably receives the gradient w.r.t. the
    B-spline basis. TODO confirm against the C kernel.

    Returns:
        None. The kernel's return value is discarded.
    """
    on_gpu = self.is_cuda
    kernel = get_func('weightingBackwardBasis', on_gpu, self)
    kernel(self, grad_output, src, weight, weight_index)