weighting.py 1.18 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
import torch
import weighting_cpu
rusty1s's avatar
rusty1s committed
3
4


rusty1s's avatar
rusty1s committed
5
6
7
8
def get_func(name, tensor):
    """Return the kernel called ``name`` from the compiled weighting extension.

    Only ``weighting_cpu`` is consulted. ``tensor`` is accepted for a future
    device-based dispatch (choosing a CUDA module when ``tensor.is_cuda``),
    but that dispatch is currently disabled, so ``tensor`` is not inspected.

    Args:
        name: attribute name of the kernel, e.g. ``'weighting_fw'``.
        tensor: ignored for now; kept so callers need not change once
            device dispatch is re-enabled.

    Returns:
        The attribute ``name`` of ``weighting_cpu``.

    Raises:
        AttributeError: if ``weighting_cpu`` has no attribute ``name``.
    """
    return getattr(weighting_cpu, name)
rusty1s's avatar
rusty1s committed
9
10


rusty1s's avatar
rusty1s committed
11
class SplineWeighting(torch.autograd.Function):
    """Autograd wrapper around the ``weighting_*`` kernels of the extension.

    ``forward`` evaluates the ``weighting_fw`` kernel. ``backward`` calls
    ``weighting_bw_x``, ``weighting_bw_w`` and ``weighting_bw_b``, each one
    only when the corresponding input needs a gradient. ``weight_index``
    never receives a gradient.
    """

    @staticmethod
    def forward(ctx, x, weight, basis, weight_index):
        """Run the forward weighting kernel.

        Args:
            ctx: autograd context used to stash values for ``backward``.
            x: forwarded to ``weighting_fw`` and saved for backward.
            weight: forwarded to ``weighting_fw`` and saved for backward.
            basis: forwarded to ``weighting_fw`` and saved for backward.
            weight_index: forwarded to ``weighting_fw`` unchanged and kept on
                ``ctx`` for backward. NOTE(review): presumably an integer
                index tensor — confirm against callers.

        Returns:
            The result of the ``weighting_fw`` kernel.
        """
        # NOTE(review): if weight_index is a tensor, PyTorch recommends passing
        # it through save_for_backward (instead of a ctx attribute) so in-place
        # modification between forward and backward is detected.
        ctx.weight_index = weight_index
        ctx.save_for_backward(x, weight, basis)
        op = get_func('weighting_fw', x)
        return op(x, weight, basis, weight_index)

    @staticmethod
    def backward(ctx, grad_out):
        """Compute gradients w.r.t. ``x``, ``weight`` and ``basis``.

        Each gradient kernel is called only if ``ctx.needs_input_grad``
        requests it. Otherwise the gradient stays ``None``.

        Args:
            ctx: autograd context populated by ``forward``.
            grad_out: gradient of the loss w.r.t. the forward output.

        Returns:
            A 4-tuple ``(grad_x, grad_weight, grad_basis, None)``, one entry
            per ``forward`` input. ``weight_index`` always gets ``None``.
        """
        x, weight, basis = ctx.saved_tensors
        weight_index = ctx.weight_index
        grad_x = grad_weight = grad_basis = None

        if ctx.needs_input_grad[0]:
            op = get_func('weighting_bw_x', x)
            grad_x = op(grad_out, weight, basis, weight_index)

        if ctx.needs_input_grad[1]:
            op = get_func('weighting_bw_w', x)
            # The kernel also receives weight.size(0); presumably used to size
            # grad_weight to match weight's leading dimension — confirm.
            grad_weight = op(grad_out, x, basis, weight_index,
                             weight.size(0))

        if ctx.needs_input_grad[2]:
            op = get_func('weighting_bw_b', x)
            grad_basis = op(grad_out, x, weight, weight_index)

        return grad_x, grad_weight, grad_basis, None