# spline_weighting.py
import torch
from torch.autograd import Function

from .ffi import (
    spline_basis_forward,
    spline_basis_backward,
    spline_weighting_forward,
    spline_weighting_backward_input,
    spline_weighting_backward_basis,
    spline_weighting_backward_weight,
)


class SplineWeighting(Function):
    """Autograd function chaining the spline basis and weighting kernels.

    The forward pass evaluates ``spline_basis_forward`` on ``pseudo`` and
    passes the resulting basis and weight indices, together with ``x`` and
    ``weight``, to ``spline_weighting_forward``.  The backward pass calls
    the matching backward kernels, but only for those inputs that are
    flagged in ``needs_input_grad``.

    NOTE(review): this class uses the instance-style ``Function`` API
    (constructor arguments, non-static ``forward``/``backward``).  Newer
    PyTorch releases expect static methods that receive a ``ctx`` object;
    confirm the targeted PyTorch version before modernising.

    Args:
        kernel_size: Forwarded unchanged to the spline basis kernels.
        is_open_spline: Forwarded unchanged to the spline basis kernels.
        degree: Forwarded unchanged to the spline basis kernels.
    """

    def __init__(self, kernel_size, is_open_spline, degree):
        super(SplineWeighting, self).__init__()
        self.kernel_size = kernel_size
        self.is_open_spline = is_open_spline
        self.degree = degree

    def forward(self, x, pseudo, weight):
        """Compute the spline-weighted output for ``x``.

        Args:
            x: Input passed to ``spline_weighting_forward``.
            pseudo: Input passed to ``spline_basis_forward``.
            weight: Weights; the size of its first dimension is handed to
                ``spline_basis_forward``.

        Returns:
            The result of ``spline_weighting_forward``.
        """
        K = weight.size(0)
        basis, weight_index = spline_basis_forward(
            self.degree, pseudo, self.kernel_size, self.is_open_spline, K)
        output = spline_weighting_forward(x, weight, basis, weight_index)

        self.save_for_backward(x, pseudo, weight)
        # The basis products are kept on the instance so that backward can
        # reuse them instead of recomputing them.
        self.basis, self.weight_index = basis, weight_index

        return output

    def backward(self, grad_output):  # pragma: no cover
        """Compute gradients with respect to ``x``, ``pseudo`` and ``weight``.

        Args:
            grad_output: Gradient with respect to the forward output.

        Returns:
            A tuple ``(grad_input, grad_pseudo, grad_weight)``.  An entry is
            ``None`` when the corresponding forward input does not need a
            gradient.
        """
        x, pseudo, weight = self.saved_tensors
        basis, weight_index = self.basis, self.weight_index
        grad_input, grad_pseudo, grad_weight = None, None, None

        if self.needs_input_grad[0]:
            grad_input = spline_weighting_backward_input(
                grad_output, weight, basis, weight_index)

        if self.needs_input_grad[1]:
            # The gradient for ``pseudo`` is obtained in two steps: first
            # the gradient of the basis, then that gradient is pushed
            # through the basis kernel's own backward.
            grad_basis = spline_weighting_backward_basis(
                grad_output, x, weight, weight_index)
            grad_pseudo = spline_basis_backward(self.degree, grad_basis,
                                                pseudo, self.kernel_size,
                                                self.is_open_spline)

        if self.needs_input_grad[2]:
            K = weight.size(0)
            grad_weight = spline_weighting_backward_weight(
                grad_output, x, basis, weight_index, K)

        return grad_input, grad_pseudo, grad_weight


def spline_weighting(x, pseudo, weight, kernel_size, is_open_spline, degree):
    """Apply spline weighting to ``x``, choosing the code path by input type.

    When ``torch.is_tensor(x)`` is false, the call is routed through a
    ``SplineWeighting`` autograd function (presumably because ``x`` is an
    autograd ``Variable`` in the PyTorch version this targets -- confirm).
    Otherwise the forward kernels are invoked directly, without going
    through ``SplineWeighting``.

    Args:
        x: Input handed to the weighting kernel.
        pseudo: Input handed to the spline basis kernel.
        weight: Weights; the size of its first dimension is passed to the
            spline basis kernel.
        kernel_size: Forwarded unchanged to the spline basis kernel.
        is_open_spline: Forwarded unchanged to the spline basis kernel.
        degree: Forwarded unchanged to the spline basis kernel.

    Returns:
        The output of ``spline_weighting_forward`` (tensor path) or of the
        ``SplineWeighting`` function call (non-tensor path).
    """
    if not torch.is_tensor(x):
        weighting_fn = SplineWeighting(kernel_size, is_open_spline, degree)
        return weighting_fn(x, pseudo, weight)

    # Plain tensor: call the forward kernels directly.
    first_dim = weight.size(0)
    spline_basis, index = spline_basis_forward(
        degree, pseudo, kernel_size, is_open_spline, first_dim)
    return spline_weighting_forward(x, weight, spline_basis, index)