basis.py 1.35 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
import torch
rusty1s's avatar
rusty1s committed
2
import basis_cpu
rusty1s's avatar
rusty1s committed
3

rusty1s's avatar
rusty1s committed
4
# Maps a supported spline degree to the name prefix of its backend operator;
# `fw` and `bw` append '_fw' / '_bw' to this prefix before looking it up.
implemented_degrees = {1: 'linear', 2: 'quadratic', 3: 'cubic'}
rusty1s's avatar
rusty1s committed
5
6


rusty1s's avatar
rusty1s committed
7
8
9
10
11
12
13
14
15
def get_func(name, tensor):
    """Return the backend operator called ``name``.

    ``tensor`` is accepted but not inspected at the moment: the lookup
    always goes to ``basis_cpu``.

    Raises:
        AttributeError: if ``basis_cpu`` has no attribute ``name``.
    """
    # Device-based dispatch is currently switched off:
    # module = basis_cuda if tensor.is_cuda else basis_cpu
    return getattr(basis_cpu, name)


def fw(pseudo, kernel_size, is_open_spline, degree):
    """Call the forward basis operator registered for ``degree``.

    The operator name is ``'<prefix>_fw'`` where ``<prefix>`` comes from
    ``implemented_degrees[degree]``; it is resolved through ``get_func``.

    Args:
        pseudo: forwarded unchanged to the backend operator (and to
            ``get_func`` for operator selection).
        kernel_size: forwarded unchanged to the backend operator.
        is_open_spline: forwarded unchanged to the backend operator.
        degree: key into ``implemented_degrees``.

    Returns:
        The ``(basis, weight_index)`` pair produced by the backend operator.

    Raises:
        KeyError: if ``degree`` is not a key of ``implemented_degrees``.
    """
    op = get_func('{}_fw'.format(implemented_degrees[degree]), pseudo)
    basis, weight_index = op(pseudo, kernel_size, is_open_spline)
    return basis, weight_index
rusty1s's avatar
rusty1s committed
17
18


rusty1s's avatar
rusty1s committed
19
20
21
22
def bw(grad_basis, pseudo, kernel_size, is_open_spline, degree):
    """Call the backward basis operator registered for ``degree``.

    The operator name is ``'<prefix>_bw'`` where ``<prefix>`` comes from
    ``implemented_degrees[degree]``; it is resolved through ``get_func``.

    Returns:
        Whatever the backend operator returns for
        ``(grad_basis, pseudo, kernel_size, is_open_spline)``; callers in
        this file treat it as the gradient with respect to ``pseudo``.

    Raises:
        KeyError: if ``degree`` is not a key of ``implemented_degrees``.
    """
    op_name = '{}_bw'.format(implemented_degrees[degree])
    backward_op = get_func(op_name, pseudo)
    return backward_op(grad_basis, pseudo, kernel_size, is_open_spline)
rusty1s's avatar
rusty1s committed
23
24


rusty1s's avatar
rusty1s committed
25
class SplineBasis(torch.autograd.Function):
    """Autograd function built on the module-level ``fw`` and ``bw`` helpers.

    Only the first input (``pseudo``) can receive a gradient; the other
    three inputs always get ``None`` from ``backward``.
    """

    @staticmethod
    def forward(ctx, pseudo, kernel_size, is_open_spline, degree):
        """Compute ``(basis, weight_index)`` via ``fw`` and stash backward state.

        Args:
            ctx: autograd context; ``pseudo`` is saved on it as a tensor, the
                remaining arguments as plain attributes.
            pseudo: input passed to ``fw``; the only differentiable input.
            kernel_size: passed through to ``fw`` / ``bw``.
            is_open_spline: passed through to ``fw`` / ``bw``.
            degree: key into ``implemented_degrees`` (see ``fw``).

        Returns:
            The ``(basis, weight_index)`` tuple returned by ``fw``.
        """
        ctx.save_for_backward(pseudo)
        ctx.kernel_size = kernel_size
        ctx.is_open_spline = is_open_spline
        ctx.degree = degree

        basis, weight_index = fw(pseudo, kernel_size, is_open_spline, degree)
        # backward() ignores the gradient arriving for weight_index, so
        # declare that output non-differentiable instead of silently
        # dropping it. NOTE(review): assumes weight_index is a tensor.
        ctx.mark_non_differentiable(weight_index)
        return basis, weight_index

    @staticmethod
    def backward(ctx, grad_basis, grad_weight_index):
        """Return the gradient for ``pseudo`` and ``None`` for the other inputs.

        ``grad_weight_index`` is intentionally unused. ``bw`` is only called
        when autograd reports that ``pseudo`` needs a gradient.
        """
        pseudo, = ctx.saved_tensors
        grad_pseudo = None

        if ctx.needs_input_grad[0]:
            grad_pseudo = bw(grad_basis, pseudo, ctx.kernel_size,
                             ctx.is_open_spline, ctx.degree)

        # One entry per forward() input: pseudo, kernel_size,
        # is_open_spline, degree.
        return grad_pseudo, None, None, None