basis.py 1.38 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
import torch
rusty1s's avatar
rusty1s committed
2
import basis_cpu
rusty1s's avatar
rusty1s committed
3

rusty1s's avatar
rusty1s committed
4
5
6
# Only import the CUDA extension when PyTorch reports an available CUDA device.
if torch.cuda.is_available():
    import basis_cuda

rusty1s's avatar
rusty1s committed
7
# Maps each supported spline degree to the name prefix used by `fw` / `bw` to
# build the `<prefix>_fw` / `<prefix>_bw` function name looked up in the
# extension module.
implemented_degrees = {1: 'linear', 2: 'quadratic', 3: 'cubic'}
rusty1s's avatar
rusty1s committed
8
9


rusty1s's avatar
rusty1s committed
10
def get_func(name, tensor):
    """Return the attribute ``name`` from the extension matching ``tensor``.

    The lookup goes to ``basis_cuda`` when ``tensor.is_cuda`` is true and to
    ``basis_cpu`` otherwise.

    Args:
        name: Attribute name to fetch from the selected extension module.
        tensor: Object whose ``is_cuda`` flag selects the module.

    Returns:
        Whatever ``getattr`` finds under ``name`` in the selected module.

    Raises:
        AttributeError: If the selected module has no attribute ``name``.
    """
    if tensor.is_cuda:
        return getattr(basis_cuda, name)
    return getattr(basis_cpu, name)


def fw(pseudo, kernel_size, is_open_spline, degree):
    """Run the forward basis function for the given spline ``degree``.

    The function name is built as ``'<prefix>_fw'`` from
    ``implemented_degrees[degree]`` and resolved through ``get_func`` using
    ``pseudo`` to choose the extension module.

    Args:
        pseudo: First argument forwarded to the extension function; also
            selects the CPU or CUDA module.
        kernel_size: Forwarded unchanged to the extension function.
        is_open_spline: Forwarded unchanged to the extension function.
        degree: Key into ``implemented_degrees``.

    Returns:
        The ``(basis, weight_index)`` pair produced by the extension function.

    Raises:
        KeyError: If ``degree`` is not a key of ``implemented_degrees``.
    """
    func_name = '{}_fw'.format(implemented_degrees[degree])
    forward_op = get_func(func_name, pseudo)
    # Unpack explicitly so the result is always handed back as a 2-tuple.
    basis, weight_index = forward_op(pseudo, kernel_size, is_open_spline)
    return basis, weight_index
rusty1s's avatar
rusty1s committed
19
20


rusty1s's avatar
rusty1s committed
21
22
23
24
def bw(grad_basis, pseudo, kernel_size, is_open_spline, degree):
    """Run the backward basis function for the given spline ``degree``.

    The function name is built as ``'<prefix>_bw'`` from
    ``implemented_degrees[degree]`` and resolved through ``get_func`` using
    ``pseudo`` to choose the extension module.

    Args:
        grad_basis: First argument forwarded to the extension function.
        pseudo: Forwarded to the extension function; also selects the CPU or
            CUDA module.
        kernel_size: Forwarded unchanged to the extension function.
        is_open_spline: Forwarded unchanged to the extension function.
        degree: Key into ``implemented_degrees``.

    Returns:
        The value returned by the extension function (used by callers as the
        gradient with respect to ``pseudo``).

    Raises:
        KeyError: If ``degree`` is not a key of ``implemented_degrees``.
    """
    func_name = '{}_bw'.format(implemented_degrees[degree])
    backward_op = get_func(func_name, pseudo)
    return backward_op(grad_basis, pseudo, kernel_size, is_open_spline)
rusty1s's avatar
rusty1s committed
25
26


rusty1s's avatar
rusty1s committed
27
class SplineBasis(torch.autograd.Function):
    """Autograd function that delegates to ``fw`` (forward) and ``bw`` (backward)."""

    @staticmethod
    def forward(ctx, pseudo, kernel_size, is_open_spline, degree):
        """Compute the forward result via ``fw`` and record state for backward.

        ``pseudo`` is stored with ``save_for_backward``; the remaining
        arguments are kept as plain attributes on ``ctx``.

        Returns:
            The ``(basis, weight_index)`` pair returned by ``fw``.
        """
        ctx.save_for_backward(pseudo)
        # Remaining arguments are stashed as plain attributes for backward.
        ctx.kernel_size = kernel_size
        ctx.is_open_spline = is_open_spline
        ctx.degree = degree

        outputs = fw(pseudo, kernel_size, is_open_spline, degree)
        return outputs

    @staticmethod
    def backward(ctx, grad_basis, grad_weight_index):
        """Return one gradient slot per ``forward`` input.

        Only the first slot (for ``pseudo``) can be non-``None``; it is
        computed with ``bw`` when ``ctx.needs_input_grad[0]`` is set.
        ``grad_weight_index`` is accepted but not used.
        """
        (pseudo,) = ctx.saved_tensors

        if not ctx.needs_input_grad[0]:
            return None, None, None, None

        grad_pseudo = bw(
            grad_basis,
            pseudo,
            ctx.kernel_size,
            ctx.is_open_spline,
            ctx.degree,
        )
        return grad_pseudo, None, None, None