basis.py 1.27 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
import torch
from torch.autograd import Function

rusty1s's avatar
rusty1s committed
4
from .utils.ffi import fw_basis, bw_basis
rusty1s's avatar
rusty1s committed
5
6


rusty1s's avatar
rusty1s committed
7
8
9
10
11
def fw(degree, pseudo, kernel_size, is_open_spline):
    """Allocate the forward outputs and run the ``fw_basis`` FFI kernel.

    Args:
        degree: Spline degree (integer exponent base ``degree + 1``).
        pseudo: Tensor whose first dimension is the number of edges; the
            ``basis`` output inherits its dtype/device. Presumably of shape
            ``(num_edges, dim)`` — TODO confirm against the C kernel.
        kernel_size: Tensor whose length (``size(0)``) sets the number of
            basis products; the ``weight_index`` output inherits its
            dtype/device.
        is_open_spline: Passed through unchanged to ``fw_basis``.

    Returns:
        ``(basis, weight_index)``, both of shape
        ``(num_edges, (degree + 1) ** kernel_size.size(0))``. They are
        allocated uninitialized here and are expected to be filled in place
        by ``fw_basis`` (its return value is not used).
    """
    num_edges = pseudo.size(0)
    num_dims = kernel_size.size(0)
    # Every combination of (degree + 1) neighbouring knots per dimension.
    num_products = (degree + 1)**num_dims
    out_shape = (num_edges, num_products)

    basis = pseudo.new_empty(out_shape)
    weight_index = kernel_size.new_empty(out_shape)
    fw_basis(degree, basis, weight_index, pseudo, kernel_size, is_open_spline)
    return basis, weight_index
rusty1s's avatar
rusty1s committed
13
14


rusty1s's avatar
rusty1s committed
15
16
17
18
def bw(degree, grad_basis, pseudo, kernel_size, is_open_spline):
    """Run the ``bw_basis`` FFI kernel and return the gradient w.r.t. pseudo.

    Args:
        degree: Spline degree, forwarded to ``bw_basis``.
        grad_basis: Incoming gradient of the ``basis`` output.
        pseudo: Forward-pass pseudo-coordinates; the result has the same
            shape, dtype and device (``torch.empty_like``).
        kernel_size: Forwarded unchanged to ``bw_basis``.
        is_open_spline: Forwarded unchanged to ``bw_basis``.

    Returns:
        A tensor like ``pseudo``. It is allocated uninitialized here and is
        expected to be filled in place by ``bw_basis`` (its return value is
        not used).
    """
    grad_pseudo = torch.empty_like(pseudo)
    bw_basis(degree, grad_pseudo, grad_basis, pseudo, kernel_size,
             is_open_spline)
    return grad_pseudo
rusty1s's avatar
rusty1s committed
19
20


rusty1s's avatar
rusty1s committed
21
class SplineBasis(Function):
    """Autograd function producing spline basis values and weight indices.

    ``forward`` returns ``(basis, weight_index)`` as computed by ``fw``.
    Only ``basis`` is differentiable, and only with respect to ``pseudo``.
    """

    @staticmethod
    def forward(ctx, pseudo, kernel_size, is_open_spline, degree=1):
        """Compute the basis values and their weight indices.

        Args:
            ctx: Autograd context.
            pseudo: Pseudo-coordinate tensor, the only input that receives a
                gradient. Presumably ``(num_edges, dim)`` floating point —
                TODO confirm.
            kernel_size: Tensor whose length sets the number of basis
                products per edge.
            is_open_spline: Forwarded unchanged to the kernels.
            degree: Spline degree (default 1).

        Returns:
            ``(basis, weight_index)`` from ``fw``.
        """
        # Tensors needed in backward go through save_for_backward, so that
        # autograd can detect in-place modification between forward and
        # backward. Plain Python values live directly on ctx.
        ctx.save_for_backward(pseudo, kernel_size)
        ctx.is_open_spline = is_open_spline
        ctx.degree = degree

        basis, weight_index = fw(degree, pseudo, kernel_size, is_open_spline)
        # backward never propagates a gradient through weight_index (it is an
        # index tensor), so declare it non-differentiable explicitly.
        ctx.mark_non_differentiable(weight_index)
        return basis, weight_index

    @staticmethod
    def backward(ctx, grad_basis, grad_weight_index):
        """Return gradients for (pseudo, kernel_size, is_open_spline, degree).

        ``grad_weight_index`` is ignored because ``weight_index`` is marked
        non-differentiable in ``forward``.
        """
        pseudo, kernel_size = ctx.saved_tensors

        grad_pseudo = None
        if ctx.needs_input_grad[0]:
            grad_pseudo = bw(ctx.degree, grad_basis, pseudo, kernel_size,
                             ctx.is_open_spline)

        # One entry per forward input; only pseudo is differentiable.
        return grad_pseudo, None, None, None