weighting.py
import torch


@torch.jit.script
def spline_weighting(x: torch.Tensor, weight: torch.Tensor,
                     basis: torch.Tensor,
                     weight_index: torch.Tensor) -> torch.Tensor:
    # Thin TorchScript wrapper around the custom weighting operator registered
    # under `torch.ops.spline_conv`: per edge, it applies the kernel weight
    # matrices selected by `weight_index`, blended by the B-spline `basis`
    # values, to the input features `x`.
    return torch.ops.spline_conv.spline_weighting(x, weight, basis,
                                                  weight_index)
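
# Example usage (a minimal sketch; the shapes follow the SplineCNN formulation
# and are illustrative rather than taken from this repository's tests):
#
#   x = torch.rand(1024, 16)                     # [num_edges, in_channels]
#   weight = torch.rand(25, 16, 32)              # [kernel_size, in_channels, out_channels]
#   basis = torch.rand(1024, 4)                  # [num_edges, (degree + 1) ** dim]
#   weight_index = torch.randint(25, (1024, 4))  # kernel indices per basis product
#   out = spline_weighting(x, weight, basis, weight_index)  # [num_edges, out_channels]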


# Legacy manual autograd implementation, kept for reference; it appears to be
# superseded by the custom operator above, which handles its own backward pass.
# class SplineWeighting(torch.autograd.Function):
#     @staticmethod
#     def forward(ctx, x, weight, basis, weight_index):
#         ctx.weight_index = weight_index
#         ctx.save_for_backward(x, weight, basis)
#         op = get_func('weighting_fw', x)
#         out = op(x, weight, basis, weight_index)
#         return out

#     @staticmethod
#     def backward(ctx, grad_out):
#         x, weight, basis = ctx.saved_tensors
#         grad_x = grad_weight = grad_basis = None

#         if ctx.needs_input_grad[0]:
#             op = get_func('weighting_bw_x', x)
#             grad_x = op(grad_out, weight, basis, ctx.weight_index)

#         if ctx.needs_input_grad[1]:
#             op = get_func('weighting_bw_w', x)
#             grad_weight = op(grad_out, x, basis, ctx.weight_index,
#                              weight.size(0))

#         if ctx.needs_input_grad[2]:
#             op = get_func('weighting_bw_b', x)
#             grad_basis = op(grad_out, x, weight, ctx.weight_index)

#         return grad_x, grad_weight, grad_basis, None
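

# For orientation, a pure-PyTorch sketch of what the weighting operator
# computes (shapes as in the usage example above; this helper is illustrative
# and not part of the library's API):
#
#   def spline_weighting_reference(x, weight, basis, weight_index):
#       E, S = basis.size()
#       K, M_in, M_out = weight.size()
#       # Gather the S kernel matrices addressed by each edge and blend them
#       # with the B-spline basis values:
#       #   out[e, o] = sum_{s, i} basis[e, s] * x[e, i] * w[e, s, i, o]
#       w = weight[weight_index.view(-1)].view(E, S, M_in, M_out)
#       return torch.einsum('es,ei,esio->eo', basis, x, w)
#
# Gradients of the custom operator could then be sanity-checked against this
# reference, or via torch.autograd.gradcheck on small double-precision inputs.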