weighting.py 1.21 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
import torch
import weighting_cpu
rusty1s's avatar
rusty1s committed
3

rusty1s's avatar
rusty1s committed
4
5
6
if torch.cuda.is_available():
    import weighting_cuda

rusty1s's avatar
rusty1s committed
7

rusty1s's avatar
rusty1s committed
8
def get_func(name, tensor):
    """Look up the compiled routine called ``name`` for ``tensor``'s device.

    The attribute is fetched from ``weighting_cuda`` when ``tensor.is_cuda``
    is true and from ``weighting_cpu`` otherwise.
    """
    if tensor.is_cuda:
        backend = weighting_cuda
    else:
        backend = weighting_cpu
    return getattr(backend, name)
rusty1s's avatar
rusty1s committed
11
12


rusty1s's avatar
rusty1s committed
13
class SplineWeighting(torch.autograd.Function):
    """Autograd function wrapping the compiled ``weighting_*`` routines.

    ``forward`` dispatches to ``weighting_fw``; ``backward`` dispatches to
    ``weighting_bw_x``, ``weighting_bw_w`` and ``weighting_bw_b`` for the
    inputs that require a gradient. The CPU or CUDA implementation is
    selected via ``get_func`` based on the device of ``x``.
    """

    @staticmethod
    def forward(ctx, x, weight, basis, weight_index):
        """Run ``weighting_fw`` on the inputs and return its result.

        All four inputs are stashed for the backward pass.
        NOTE(review): assumes ``weight_index`` is a tensor (it is handed
        straight to the compiled op) -- confirm against callers.
        """
        # Tensors needed in backward go through save_for_backward instead of
        # being attached to ctx as plain attributes, so autograd can check
        # them for in-place modification and manage their lifetime.
        ctx.save_for_backward(x, weight, basis, weight_index)
        op = get_func('weighting_fw', x)
        return op(x, weight, basis, weight_index)

    @staticmethod
    def backward(ctx, grad_out):
        """Return gradients for ``(x, weight, basis, weight_index)``.

        A gradient is only computed for an input whose entry in
        ``ctx.needs_input_grad`` is set; the others are returned as ``None``.
        ``weight_index`` never receives a gradient.
        """
        x, weight, basis, weight_index = ctx.saved_tensors
        grad_x = grad_weight = grad_basis = None

        if ctx.needs_input_grad[0]:
            op = get_func('weighting_bw_x', x)
            grad_x = op(grad_out, weight, basis, weight_index)

        if ctx.needs_input_grad[1]:
            op = get_func('weighting_bw_w', x)
            # The compiled routine also receives the size of weight's first
            # dimension.
            grad_weight = op(grad_out, x, basis, weight_index,
                             weight.size(0))

        if ctx.needs_input_grad[2]:
            op = get_func('weighting_bw_b', x)
            grad_basis = op(grad_out, x, weight, weight_index)

        return grad_x, grad_weight, grad_basis, None