basis.py 1.77 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
import torch
from torch.autograd import Function

rusty1s's avatar
rusty1s committed
4
5
from .utils.ffi import basis_forward as basis_fw
from .utils.ffi import basis_backward as basis_bw
rusty1s's avatar
rusty1s committed
6
7
8
9
10
11


def basis_forward(degree, pseudo, kernel_size, is_open_spline):
    """Allocate the output buffers and run the FFI basis forward routine.

    Two ``(num_rows, num_cols)`` buffers are created, where ``num_rows`` is
    ``pseudo.size(0)`` and ``num_cols`` is
    ``(degree + 1) ** kernel_size.size(0)``:

    * one via ``pseudo.new`` (same tensor type as ``pseudo``),
    * one via ``kernel_size.new`` (same tensor type as ``kernel_size``).

    Both are handed to ``basis_fw`` together with the inputs. The return
    value of ``basis_fw`` is ignored, so it is presumably expected to fill
    the buffers in place -- confirm against the FFI implementation.

    Returns:
        The tuple ``(basis, weight_index)`` made of the two buffers above.
    """
    num_rows = pseudo.size(0)
    num_cols = (degree + 1)**kernel_size.size(0)

    # ``new`` with explicit sizes only allocates; the contents are whatever
    # ``basis_fw`` writes into them.
    out_basis = pseudo.new(num_rows, num_cols)
    out_weight_index = kernel_size.new(num_rows, num_cols)

    basis_fw(degree, out_basis, out_weight_index, pseudo, kernel_size,
             is_open_spline)

    return out_basis, out_weight_index
rusty1s's avatar
rusty1s committed
14
15


rusty1s's avatar
rusty1s committed
16
17
def basis_backward(degree, grad_basis, pseudo, kernel_size,
                   is_open_spline):  # pragma: no cover
    """Allocate the gradient buffer and run the FFI basis backward routine.

    A buffer with the same type and size as ``pseudo`` is created and passed
    to ``basis_bw`` along with ``grad_basis`` and the forward inputs. The
    return value of ``basis_bw`` is ignored, so it is presumably expected to
    write the gradient into that buffer -- confirm against the FFI
    implementation.

    Returns:
        The buffer holding the gradient with respect to ``pseudo``.
    """
    # Allocation only; the values come from ``basis_bw``.
    out_grad = pseudo.new(pseudo.size())

    basis_bw(degree, out_grad, grad_basis, pseudo, kernel_size,
             is_open_spline)

    return out_grad
rusty1s's avatar
rusty1s committed
22
23


rusty1s's avatar
rusty1s committed
24
class SplineBasis(Function):
    """Autograd ``Function`` wrapping ``basis_forward``/``basis_backward``.

    This is an instance-style ``Function``: the non-tensor configuration
    (``degree``, ``kernel_size``, ``is_open_spline``) is passed to
    ``__init__`` and stored on the instance, and ``forward``/``backward``
    are regular methods that read it back from ``self``.
    """

    def __init__(self, degree, kernel_size, is_open_spline):
        super(SplineBasis, self).__init__()

        # Configuration forwarded unchanged to basis_forward/basis_backward.
        self.degree = degree
        self.kernel_size = kernel_size
        self.is_open_spline = is_open_spline

    def forward(self, pseudo):
        """Save ``pseudo`` for backward and return ``basis_forward``'s result.

        Returns:
            Whatever ``basis_forward`` returns for the stored configuration,
            i.e. the ``(basis, weight_index)`` pair.
        """
        # backward() needs the input to compute the gradient w.r.t. pseudo.
        self.save_for_backward(pseudo)

        outputs = basis_forward(self.degree, pseudo, self.kernel_size,
                                self.is_open_spline)
        return outputs

    def backward(self, grad_basis, grad_weight_index):  # pragma: no cover
        """Return the gradient w.r.t. ``pseudo``, or ``None`` if not needed.

        NOTE(review): ``grad_weight_index`` is accepted because ``forward``
        has two outputs, but it is never used here -- presumably because
        ``weight_index`` is an index tensor; confirm.
        """
        # Read the saved input unconditionally, before checking whether a
        # gradient is required at all.
        (pseudo, ) = self.saved_tensors

        if not self.needs_input_grad[0]:
            return None

        return basis_backward(self.degree, grad_basis, pseudo,
                              self.kernel_size, self.is_open_spline)


def spline_basis(degree, pseudo, kernel_size, is_open_spline):
    """Compute the spline basis, choosing the plain or the autograd path.

    If ``torch.is_tensor(pseudo)`` is true, ``basis_forward`` is called
    directly. Otherwise ``pseudo`` is passed through a freshly constructed
    ``SplineBasis`` function (presumably because it is an autograd
    ``Variable`` rather than a raw tensor -- confirm against the targeted
    PyTorch version).

    Returns:
        The result of ``basis_forward`` or of applying ``SplineBasis``,
        depending on the branch taken.
    """
    if not torch.is_tensor(pseudo):
        # Non-tensor input: go through the autograd Function.
        autograd_op = SplineBasis(degree, kernel_size, is_open_spline)
        return autograd_op(pseudo)

    # Raw tensor: call the FFI-backed forward directly.
    return basis_forward(degree, pseudo, kernel_size, is_open_spline)