weighting.py 1.78 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
from torch.autograd import Function

rusty1s's avatar
rusty1s committed
3
4
from .utils.ffi import fw_weighting, bw_weighting_src
from .utils.ffi import bw_weighting_weight, bw_weighting_basis
rusty1s's avatar
rusty1s committed
5
6


rusty1s's avatar
rusty1s committed
7
8
9
def fw(src, weight, basis, weight_index):
    """Forward pass of the spline weighting.

    Allocates an uninitialised result tensor (same dtype/device as ``src``)
    with ``src.size(0)`` rows and ``weight.size(2)`` columns. It then hands
    that tensor to the ``fw_weighting`` FFI kernel, which presumably writes
    every entry in place, and returns it.

    Args:
        src: Input features; row count defines the number of output rows.
        weight: Weight tensor; its third dimension defines the output width.
            Presumably shaped (K, in_features, out_features) — TODO confirm.
        basis: Spline basis values consumed by the kernel.
        weight_index: Index tensor selecting which of the K weights apply.

    Returns:
        The tensor filled by ``fw_weighting``.
    """
    num_rows = src.size(0)
    num_out = weight.size(2)
    result = src.new_empty((num_rows, num_out))
    # The C kernel writes into ``result``; nothing else initialises it.
    fw_weighting(result, src, weight, basis, weight_index)
    return result


rusty1s's avatar
rusty1s committed
13
14
15
def bw_src(grad_output, weight, basis, weight_index):
    """Gradient of the spline weighting with respect to ``src``.

    The gradient buffer has ``grad_output.size(0)`` rows and
    ``weight.size(1)`` columns, and is created uninitialised on the same
    dtype/device as ``grad_output``. The ``bw_weighting_src`` FFI kernel
    fills it in place (presumably overwriting every entry).

    Args:
        grad_output: Incoming gradient of the forward output.
        weight: Weight tensor from the forward pass.
        basis: Spline basis values from the forward pass.
        weight_index: Index tensor from the forward pass.

    Returns:
        The buffer written by ``bw_weighting_src``.
    """
    shape = (grad_output.size(0), weight.size(1))
    buf = grad_output.new_empty(shape)
    bw_weighting_src(buf, grad_output, weight, basis, weight_index)
    return buf


rusty1s's avatar
rusty1s committed
19
20
21
def bw_weight(grad_output, src, basis, weight_index, K):
    """Gradient of the spline weighting with respect to ``weight``.

    Builds an uninitialised buffer of shape
    ``(K, src.size(1), grad_output.size(1))`` on ``src``'s dtype/device.
    It lets ``bw_weighting_weight`` populate it and returns it.

    Args:
        grad_output: Incoming gradient of the forward output.
        src: Input features saved from the forward pass.
        basis: Spline basis values from the forward pass.
        weight_index: Index tensor from the forward pass.
        K: Size of the leading weight dimension (``weight.size(0)``).

    Returns:
        The buffer written by ``bw_weighting_weight``.
    """
    in_features = src.size(1)
    out_features = grad_output.size(1)
    buf = src.new_empty((K, in_features, out_features))
    # NOTE(review): ``new_empty`` does not zero memory. If the kernel
    # accumulates into ``buf``, it must clear it first — confirm in the C code.
    bw_weighting_weight(buf, grad_output, src, basis, weight_index)
    return buf


rusty1s's avatar
rusty1s committed
25
26
27
def bw_basis(grad_output, src, weight, weight_index):
    """Gradient of the spline weighting with respect to ``basis``.

    The result has the same size as ``weight_index``, with ``src``'s
    dtype/device. It is allocated uninitialised and filled in place by the
    ``bw_weighting_basis`` FFI kernel.

    Args:
        grad_output: Incoming gradient of the forward output.
        src: Input features saved from the forward pass.
        weight: Weight tensor from the forward pass.
        weight_index: Index tensor; its size fixes the gradient's shape.

    Returns:
        The buffer written by ``bw_weighting_basis``.
    """
    target_shape = weight_index.size()
    buf = src.new_empty(target_shape)
    bw_weighting_basis(buf, grad_output, src, weight, weight_index)
    return buf


class SplineWeighting(Function):
    """Autograd function combining ``src`` with spline-selected weights.

    The forward and backward passes delegate to the module-level ``fw`` and
    ``bw_*`` helpers, which call the FFI kernels. ``weight_index`` is
    treated as non-differentiable: its gradient slot is always ``None``.
    """

    @staticmethod
    def forward(ctx, src, weight, basis, weight_index):
        # All four inputs are needed again to evaluate the backward kernels.
        ctx.save_for_backward(src, weight, basis, weight_index)
        out = fw(src, weight, basis, weight_index)
        return out

    @staticmethod
    def backward(ctx, grad_output):  # pragma: no cover
        src, weight, basis, weight_index = ctx.saved_tensors
        want_src, want_weight, want_basis = ctx.needs_input_grad[:3]

        # Only run the kernels whose gradients autograd actually requested.
        grad_src = (bw_src(grad_output, weight, basis, weight_index)
                    if want_src else None)
        grad_weight = (bw_weight(grad_output, src, basis, weight_index,
                                 weight.size(0))
                       if want_weight else None)
        grad_basis = (bw_basis(grad_output, src, weight, weight_index)
                      if want_basis else None)

        # No gradient is produced for ``weight_index``.
        return grad_src, grad_weight, grad_basis, None