weighting.py 1.74 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
from torch.autograd import Function

rusty1s's avatar
rusty1s committed
3
4
from .utils.ffi import fw_weighting, bw_weighting_src
from .utils.ffi import bw_weighting_weight, bw_weighting_basis
rusty1s's avatar
rusty1s committed
5
6


rusty1s's avatar
rusty1s committed
7
def fw(src, weight, basis, weight_index):
    """Run the forward weighting kernel and return its output tensor.

    A new, uninitialized tensor of shape ``(src.size(0), weight.size(2))``
    is created with ``src.new_empty`` (so it shares ``src``'s dtype and
    device) and handed to ``fw_weighting`` as its first argument, followed
    by ``src``, ``weight``, ``basis`` and ``weight_index``.

    The kernel is presumably responsible for writing every element of the
    buffer -- nothing else in this function initializes it.
    """
    rows = src.size(0)
    cols = weight.size(2)
    result = src.new_empty((rows, cols))
    fw_weighting(result, src, weight, basis, weight_index)
    return result
rusty1s's avatar
rusty1s committed
11
12


rusty1s's avatar
rename  
rusty1s committed
13
14
15
def bw_src(grad_out, weight, basis, weight_index):
    """Compute the gradient with respect to ``src`` via the FFI kernel.

    Allocates an uninitialized ``(grad_out.size(0), weight.size(1))`` tensor
    from ``grad_out`` (same dtype and device) and passes it to
    ``bw_weighting_src`` along with ``grad_out``, ``weight``, ``basis`` and
    ``weight_index``. The buffer is returned after the kernel call; the
    kernel is presumably what fills it.
    """
    rows = grad_out.size(0)
    cols = weight.size(1)
    result = grad_out.new_empty((rows, cols))
    bw_weighting_src(result, grad_out, weight, basis, weight_index)
    return result


rusty1s's avatar
rename  
rusty1s committed
19
20
21
def bw_weight(grad_out, src, basis, weight_index, K):
    """Compute the gradient with respect to ``weight`` via the FFI kernel.

    Allocates an uninitialized ``(K, src.size(1), grad_out.size(1))`` tensor
    from ``src`` (same dtype and device) and passes it to
    ``bw_weighting_weight`` along with ``grad_out``, ``src``, ``basis`` and
    ``weight_index``.

    NOTE(review): the buffer comes from ``new_empty`` and is not zeroed
    here, so this relies on the kernel initializing it -- confirm against
    the kernel implementation.
    """
    shape = (K, src.size(1), grad_out.size(1))
    result = src.new_empty(shape)
    bw_weighting_weight(result, grad_out, src, basis, weight_index)
    return result


rusty1s's avatar
rename  
rusty1s committed
25
def bw_basis(grad_out, src, weight, weight_index):
    """Compute the gradient with respect to ``basis`` via the FFI kernel.

    Allocates an uninitialized tensor with the same shape as
    ``weight_index`` (but with ``src``'s dtype and device, since it is
    created through ``src.new_empty``) and passes it to
    ``bw_weighting_basis`` along with ``grad_out``, ``src``, ``weight`` and
    ``weight_index``. The buffer is returned after the kernel call.
    """
    shape = weight_index.size()
    result = src.new_empty(shape)
    bw_weighting_basis(result, grad_out, src, weight, weight_index)
    return result


class SplineWeighting(Function):
    """Autograd ``Function`` that ties the weighting kernels together.

    ``forward`` delegates to ``fw``; ``backward`` delegates to ``bw_src``,
    ``bw_weight`` and ``bw_basis``, each only when autograd reports that the
    corresponding input needs a gradient. ``weight_index`` never receives a
    gradient.
    """

    @staticmethod
    def forward(ctx, src, weight, basis, weight_index):
        """Stash all four inputs for the backward pass and return ``fw``'s output."""
        inputs = (src, weight, basis, weight_index)
        ctx.save_for_backward(*inputs)
        return fw(*inputs)

    @staticmethod
    def backward(ctx, grad_out):  # pragma: no cover
        """Return ``(grad_src, grad_weight, grad_basis, None)``.

        An entry is ``None`` whenever ``ctx.needs_input_grad`` says the
        matching forward input does not require a gradient.
        """
        src, weight, basis, weight_index = ctx.saved_tensors
        needs_grad = ctx.needs_input_grad

        grad_src = (
            bw_src(grad_out, weight, basis, weight_index)
            if needs_grad[0] else None
        )

        # The first dimension of ``weight`` is forwarded as ``K`` so the
        # gradient buffer is allocated with a matching leading dimension.
        grad_weight = (
            bw_weight(grad_out, src, basis, weight_index, weight.size(0))
            if needs_grad[1] else None
        )

        grad_basis = (
            bw_basis(grad_out, src, weight, weight_index)
            if needs_grad[2] else None
        )

        # One slot per forward input; ``weight_index`` gets no gradient.
        return grad_src, grad_weight, grad_basis, None