basis.py 1.26 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
from typing import Tuple
rusty1s's avatar
rusty1s committed
2

rusty1s's avatar
rusty1s committed
3
import torch
rusty1s's avatar
rusty1s committed
4
5


rusty1s's avatar
rusty1s committed
6
7
8
9
10
11
@torch.jit.script
def spline_basis(pseudo: torch.Tensor, kernel_size: torch.Tensor,
                 is_open_spline: torch.Tensor,
                 degree: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Evaluate B-spline basis functions via the ``torch_spline_conv`` op.

    Thin TorchScript wrapper that forwards all arguments unchanged to the
    custom operator ``torch.ops.torch_spline_conv.spline_basis`` and returns
    the ``(basis, weight_index)`` tensor pair it produces.
    """
    # The custom op performs the actual computation (forward and backward
    # are registered on the C++ side); this wrapper only fixes the schema
    # for TorchScript callers.
    basis, weight_index = torch.ops.torch_spline_conv.spline_basis(
        pseudo, kernel_size, is_open_spline, degree)
    return basis, weight_index
rusty1s's avatar
rusty1s committed
12
13


rusty1s's avatar
rusty1s committed
14
15
16
17
18
19
20
# class SplineBasis(torch.autograd.Function):
#     @staticmethod
#     def forward(ctx, pseudo, kernel_size, is_open_spline, degree):
#         ctx.save_for_backward(pseudo)
#         ctx.kernel_size = kernel_size
#         ctx.is_open_spline = is_open_spline
#         ctx.degree = degree
rusty1s's avatar
rusty1s committed
21

rusty1s's avatar
rusty1s committed
22
23
#         op = get_func('{}_fw'.format(implemented_degrees[degree]), pseudo)
#         basis, weight_index = op(pseudo, kernel_size, is_open_spline)
rusty1s's avatar
rusty1s committed
24

rusty1s's avatar
rusty1s committed
25
#         return basis, weight_index
rusty1s's avatar
rusty1s committed
26

rusty1s's avatar
rusty1s committed
27
28
29
30
31
32
#     @staticmethod
#     def backward(ctx, grad_basis, grad_weight_index):
#         pseudo, = ctx.saved_tensors
#         kernel_size, is_open_spline = ctx.kernel_size, ctx.is_open_spline
#         degree = ctx.degree
#         grad_pseudo = None
rusty1s's avatar
rusty1s committed
33

rusty1s's avatar
rusty1s committed
34
35
#         if ctx.needs_input_grad[0]:
#             grad_pseudo = op(grad_basis, pseudo, kernel_size, is_open_spline)
rusty1s's avatar
rusty1s committed
36

rusty1s's avatar
rusty1s committed
37
#         return grad_pseudo, None, None, None