basis.py 1.27 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
import torch
2
import torch_spline_conv.basis_cpu
rusty1s's avatar
rusty1s committed
3

rusty1s's avatar
rusty1s committed
4
if torch.cuda.is_available():
5
    import torch_spline_conv.basis_cuda
rusty1s's avatar
rusty1s committed
6

rusty1s's avatar
rusty1s committed
7
# Maps a spline degree to the name prefix of its backend routines; SplineBasis
# appends '_fw' / '_bw' to the prefix to pick the forward / backward routine.
implemented_degrees = {1: 'linear', 2: 'quadratic', 3: 'cubic'}
rusty1s's avatar
rusty1s committed
8
9


rusty1s's avatar
rusty1s committed
10
def get_func(name, tensor):
    """Fetch the basis routine called ``name`` for the device of ``tensor``.

    Args:
        name: Attribute name of the routine to look up on the backend module.
        tensor: Object whose ``is_cuda`` flag selects the backend module.

    Returns:
        Attribute ``name`` of ``torch_spline_conv.basis_cuda`` when
        ``tensor.is_cuda`` is true, otherwise of
        ``torch_spline_conv.basis_cpu``.

    Raises:
        AttributeError: If the selected module has no attribute ``name``.
    """
    # The CUDA module is only referenced for CUDA tensors, because it is only
    # imported when CUDA is available.
    backend = (torch_spline_conv.basis_cuda
               if tensor.is_cuda else torch_spline_conv.basis_cpu)
    return getattr(backend, name)
rusty1s's avatar
rusty1s committed
15
16
17


class SplineBasis(torch.autograd.Function):
    """Autograd function that dispatches to the backend spline basis routines.

    The forward pass calls the ``<prefix>_fw`` routine and the backward pass
    the ``<prefix>_bw`` routine, where ``<prefix>`` is looked up in
    ``implemented_degrees`` by ``degree`` and the routine is resolved with
    ``get_func`` based on the device of ``pseudo``.
    """

    @staticmethod
    def forward(ctx, pseudo, kernel_size, is_open_spline, degree):
        """Run the forward basis routine for the requested degree.

        Args:
            ctx: Autograd context; receives ``pseudo`` as a saved tensor and
                the remaining arguments as plain attributes.
            pseudo: Tensor handed to the backend routine; its device selects
                the CPU or CUDA implementation.
            kernel_size: Forwarded unchanged to the backend routine.
            is_open_spline: Forwarded unchanged to the backend routine.
            degree: Key into ``implemented_degrees``.

        Returns:
            The ``(basis, weight_index)`` pair produced by the backend.

        Raises:
            KeyError: If ``degree`` is not a key of ``implemented_degrees``.
        """
        # Keep everything the backward pass reads later on.
        ctx.save_for_backward(pseudo)
        ctx.degree = degree
        ctx.is_open_spline = is_open_spline
        ctx.kernel_size = kernel_size

        fw_name = '{}_fw'.format(implemented_degrees[degree])
        fw_op = get_func(fw_name, pseudo)
        basis, weight_index = fw_op(pseudo, kernel_size, is_open_spline)

        return basis, weight_index

    @staticmethod
    def backward(ctx, grad_basis, grad_weight_index):
        """Compute the gradient with respect to ``pseudo``.

        Args:
            ctx: Autograd context populated by ``forward``.
            grad_basis: Gradient flowing into the ``basis`` output.
            grad_weight_index: Gradient slot for the ``weight_index`` output;
                it is not used here.

        Returns:
            A 4-tuple matching the inputs of ``forward``: the gradient for
            ``pseudo`` (``None`` when it is not required) followed by
            ``None`` for ``kernel_size``, ``is_open_spline`` and ``degree``.
        """
        pseudo, = ctx.saved_tensors

        # Skip the backend call entirely when pseudo needs no gradient.
        if not ctx.needs_input_grad[0]:
            return None, None, None, None

        bw_name = '{}_bw'.format(implemented_degrees[ctx.degree])
        bw_op = get_func(bw_name, pseudo)
        grad_pseudo = bw_op(grad_basis, pseudo, ctx.kernel_size,
                            ctx.is_open_spline)

        return grad_pseudo, None, None, None