"docs/zh/usage/output_files.md" did not exist on "54cf49dfc9a55d0959be13ec85d0702757b0e842"
weighting.py 1.3 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
import torch
2
import torch_spline_conv.weighting_cpu
rusty1s's avatar
rusty1s committed
3

rusty1s's avatar
rusty1s committed
4
if torch.cuda.is_available():
5
    import torch_spline_conv.weighting_cuda
rusty1s's avatar
rusty1s committed
6

rusty1s's avatar
rusty1s committed
7

rusty1s's avatar
rusty1s committed
8
def get_func(name, tensor):
    """Return the kernel called ``name`` from the backend matching ``tensor``.

    The attribute is looked up on ``torch_spline_conv.weighting_cuda`` when
    ``tensor.is_cuda`` is true and on ``torch_spline_conv.weighting_cpu``
    otherwise.
    """
    # The conditional expression only evaluates the selected branch, so the
    # CUDA module is never touched for non-CUDA tensors.
    backend = (torch_spline_conv.weighting_cuda
               if tensor.is_cuda else torch_spline_conv.weighting_cpu)
    return getattr(backend, name)
rusty1s's avatar
rusty1s committed
13
14


rusty1s's avatar
rusty1s committed
15
class SplineWeighting(torch.autograd.Function):
    """Autograd function around the ``weighting_*`` extension kernels.

    ``forward`` calls the ``weighting_fw`` kernel; ``backward`` calls the
    ``weighting_bw_x``, ``weighting_bw_w`` and ``weighting_bw_b`` kernels for
    the inputs that require a gradient. The kernel is selected by
    ``get_func`` based on the device of ``x``.
    """

    @staticmethod
    def forward(ctx, x, weight, basis, weight_index):
        """Run the ``weighting_fw`` kernel and return its result.

        Args:
            ctx: Autograd context used to pass tensors to ``backward``.
            x: Input tensor; its device selects the CPU or CUDA kernel.
            weight: Weight tensor forwarded to the kernel.
            basis: Basis tensor forwarded to the kernel.
            weight_index: Index tensor forwarded to the kernel; it receives
                no gradient.
        """
        # Save *all* tensors needed in backward via save_for_backward instead
        # of stashing ``weight_index`` as a plain attribute on ``ctx``; this
        # is the mechanism PyTorch documents for input tensors and avoids
        # keeping them alive through the context object.
        ctx.save_for_backward(x, weight, basis, weight_index)
        op = get_func('weighting_fw', x)
        return op(x, weight, basis, weight_index)

    @staticmethod
    def backward(ctx, grad_out):
        """Compute gradients for ``x``, ``weight`` and ``basis``.

        Only the gradients flagged in ``ctx.needs_input_grad`` are computed;
        the others are returned as ``None``.

        Returns:
            A 4-tuple ``(grad_x, grad_weight, grad_basis, None)``; the last
            entry corresponds to ``weight_index``.
        """
        x, weight, basis, weight_index = ctx.saved_tensors
        grad_x = grad_weight = grad_basis = None

        if ctx.needs_input_grad[0]:
            op = get_func('weighting_bw_x', x)
            grad_x = op(grad_out, weight, basis, weight_index)

        if ctx.needs_input_grad[1]:
            op = get_func('weighting_bw_w', x)
            # The kernel additionally takes the first dimension of ``weight``
            # (presumably to size the gradient it allocates -- confirm
            # against the extension source).
            grad_weight = op(grad_out, x, basis, weight_index,
                             weight.size(0))

        if ctx.needs_input_grad[2]:
            op = get_func('weighting_bw_b', x)
            grad_basis = op(grad_out, x, weight, weight_index)

        return grad_x, grad_weight, grad_basis, None