basis.py 1.17 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
import torch
rusty1s's avatar
rusty1s committed
2
import basis_cpu
rusty1s's avatar
rusty1s committed
3

rusty1s's avatar
rusty1s committed
4
5
6
if torch.cuda.is_available():
    import basis_cuda

rusty1s's avatar
rusty1s committed
7
# Spline degree -> name prefix of the compiled kernels; the forward/backward
# kernels are looked up as '<prefix>_fw' / '<prefix>_bw'.
implemented_degrees = {
    1: 'linear',
    2: 'quadratic',
    3: 'cubic',
}
rusty1s's avatar
rusty1s committed
8
9


rusty1s's avatar
rusty1s committed
10
def get_func(name, tensor):
    """Return the compiled kernel called ``name`` for ``tensor``'s device.

    Tensors with ``is_cuda`` set are dispatched to the ``basis_cuda``
    extension (which is only imported when ``torch.cuda.is_available()``);
    all other tensors go to ``basis_cpu``. ``getattr`` raises
    ``AttributeError`` if the selected extension has no such function.
    """
    if tensor.is_cuda:
        return getattr(basis_cuda, name)
    return getattr(basis_cpu, name)


class SplineBasis(torch.autograd.Function):
    """Autograd function wrapping the compiled spline-basis kernels.

    Both passes dispatch through ``get_func`` to kernels named
    ``'<name>_fw'`` / ``'<name>_bw'``. Here ``<name>`` is
    ``implemented_degrees[degree]``, so a degree missing from that mapping
    raises ``KeyError``.
    """

    @staticmethod
    def forward(ctx, pseudo, kernel_size, is_open_spline, degree):
        """Evaluate the spline basis for ``pseudo``.

        Args:
            ctx: Autograd context; stores the state ``backward`` needs.
            pseudo: Input tensor. Its ``is_cuda`` flag selects the CPU or
                CUDA kernel, and it is the only input that receives a gradient.
            kernel_size: Forwarded unchanged to the compiled kernels.
            is_open_spline: Forwarded unchanged to the compiled kernels.
            degree: Spline degree; must be a key of ``implemented_degrees``.

        Returns:
            The ``(basis, weight_index)`` pair produced by the forward kernel.
        """
        ctx.save_for_backward(pseudo)
        ctx.kernel_size, ctx.is_open_spline = kernel_size, is_open_spline
        ctx.degree = degree

        op = get_func('{}_fw'.format(implemented_degrees[degree]), pseudo)
        basis, weight_index = op(pseudo, kernel_size, is_open_spline)

        # backward() never propagates anything through weight_index (its
        # incoming gradient is ignored), so declare it non-differentiable
        # explicitly instead of relying on autograd to infer it.
        ctx.mark_non_differentiable(weight_index)

        return basis, weight_index

    @staticmethod
    def backward(ctx, grad_basis, grad_weight_index):
        """Propagate the gradient of ``basis`` back to ``pseudo``.

        ``grad_weight_index`` is ignored because ``weight_index`` is marked
        non-differentiable in ``forward``.

        Returns:
            One entry per ``forward`` input. The first is the gradient for
            ``pseudo``, or ``None`` when autograd did not request it. The
            rest are ``None`` for ``kernel_size``, ``is_open_spline`` and
            ``degree``.
        """
        pseudo, = ctx.saved_tensors
        kernel_size, is_open_spline = ctx.kernel_size, ctx.is_open_spline
        degree = ctx.degree
        grad_pseudo = None

        if ctx.needs_input_grad[0]:
            op = get_func('{}_bw'.format(implemented_degrees[degree]), pseudo)
            grad_pseudo = op(grad_basis, pseudo, kernel_size, is_open_spline)

        return grad_pseudo, None, None, None