basis.py 1.52 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
import torch
from torch.autograd import Function

rusty1s's avatar
rusty1s committed
4
from .utils.ffi import basis_forward as ffi_basis_forward
rusty1s's avatar
rusty1s committed
5
from .utils.ffi import basis_backward as ffi_basis_backward
rusty1s's avatar
rusty1s committed
6
7
8
9
10
11
12
13
14
15


def basis_forward(degree, pseudo, kernel_size, is_open_spline):
    """Run the spline-basis forward FFI kernel on freshly allocated buffers.

    A 1-D ``pseudo`` is first promoted to two dimensions via
    ``unsqueeze(-1)``. Two output buffers are then allocated with ``.new``:
    ``basis`` (same tensor type as ``pseudo``) and ``weight_index`` (same
    tensor type as ``kernel_size``). Both are filled in place by
    ``ffi_basis_forward``.

    Args:
        degree: spline degree; ``degree + 1`` is raised to the number of
            entries in ``kernel_size`` to get the column count.
        pseudo: input tensor; its first dimension gives the row count.
        kernel_size: tensor whose length (``size(0)``) sets the exponent above.
        is_open_spline: passed through unchanged to the FFI kernel.

    Returns:
        Tuple ``(basis, weight_index)``, each of shape
        ``(pseudo.size(0), (degree + 1) ** kernel_size.size(0))``.
    """
    if pseudo.dim() == 1:
        pseudo = pseudo.unsqueeze(-1)

    n_rows = pseudo.size(0)
    n_cols = (degree + 1) ** kernel_size.size(0)

    # Output buffers; contents are written by the FFI call below.
    basis = pseudo.new(n_rows, n_cols)
    weight_index = kernel_size.new(n_rows, n_cols)

    ffi_basis_forward(degree, basis, weight_index, pseudo, kernel_size,
                      is_open_spline)
    return basis, weight_index
rusty1s's avatar
rusty1s committed
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44


def basis_backward(degree, grad_basis, pseudo, kernel_size, is_open_spline):
    """Run the spline-basis backward FFI kernel and return its output buffer.

    Allocates ``grad_pseudo`` with the same tensor type and size as
    ``pseudo`` and lets ``ffi_basis_backward`` fill it in place.

    Args:
        degree: spline degree, forwarded to the FFI kernel.
        grad_basis: incoming gradient w.r.t. the basis output (see NOTE).
        pseudo: the tensor the forward pass was evaluated on.
        kernel_size: forwarded to the FFI kernel.
        is_open_spline: forwarded to the FFI kernel.

    Returns:
        ``grad_pseudo``, a tensor with the same size as ``pseudo``.
    """
    grad_pseudo = pseudo.new(pseudo.size())
    # NOTE(review): grad_basis is accepted but not passed to the kernel, so
    # the chain rule cannot be applied here unless ffi_basis_backward obtains
    # it some other way -- confirm against the FFI signature.
    ffi_basis_backward(degree, grad_pseudo, pseudo, kernel_size,
                       is_open_spline)
    # Previously the filled buffer was discarded and callers received None.
    return grad_pseudo


class Basis(Function):
    """Legacy (instance-style) autograd Function around the spline basis.

    The configuration (``degree``, ``kernel_size``, ``is_open_spline``) is
    stored on the instance at construction. It is forwarded to
    ``basis_forward`` in ``forward`` and to ``basis_backward`` in
    ``backward``.
    """

    def __init__(self, degree, kernel_size, is_open_spline):
        super(Basis, self).__init__()
        self.degree = degree
        self.kernel_size = kernel_size
        self.is_open_spline = is_open_spline

    def forward(self, pseudo):
        """Save ``pseudo`` for the backward pass and evaluate the basis.

        Returns:
            ``(basis, weight_index)`` as produced by ``basis_forward``.
        """
        # Keep the original (possibly 1-D) input so the gradient returned by
        # backward matches its shape.
        self.save_for_backward(pseudo)
        return basis_forward(self.degree, pseudo, self.kernel_size,
                             self.is_open_spline)

    def backward(self, grad_basis, grad_weight_index):
        """Return the gradient w.r.t. ``pseudo``.

        ``grad_weight_index`` is ignored: ``weight_index`` is treated as
        non-differentiable (presumably integer indices -- TODO confirm).
        """
        pseudo, = self.saved_tensors
        return basis_backward(self.degree, grad_basis, pseudo,
                              self.kernel_size, self.is_open_spline)


def basis(degree, pseudo, kernel_size, is_open_spline):
    """Evaluate the spline basis for ``pseudo``.

    When ``torch.is_tensor(pseudo)`` is true the computation goes straight
    to ``basis_forward``. Otherwise a ``Basis`` autograd Function is built
    from the configuration and applied to ``pseudo``.

    Returns:
        ``(basis, weight_index)``.
    """
    if not torch.is_tensor(pseudo):
        op = Basis(degree, kernel_size, is_open_spline)
        return op(pseudo)
    return basis_forward(degree, pseudo, kernel_size, is_open_spline)