edgewise_spline_weighting_cpu.py 1.84 KB
Newer Older
rusty1s's avatar
rename  
rusty1s committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import torch
from torch.autograd import Function


class EdgewiseSplineWeightingCPU(Function):
    """Row-wise weighted combination of kernel slices, with a manual backward.

    For every row ``e`` of ``input`` the forward pass produces

        output[e, :] = sum_k sum_i input[e, i] * amount[e, k]
                                   * weight[index[e, k], i, :]

    Shapes, as used by the code below:
        amount: [|E| x k_max] multiplicative coefficients.
        index:  [|E| x k_max] integer positions into the first dimension of
                ``weight``.
        input:  [|E| x M_in]
        weight: [K x M_in x M_out]
        output: [|E| x M_out]

    NOTE(review): going by the names, ``amount``/``index`` are presumably
    spline basis values and the kernel indices they refer to -- confirm
    against the caller.

    NOTE(review): this uses the instance-based ``Function`` interface (state
    passed to ``__init__``, non-static ``forward``/``backward``) -- check that
    the targeted PyTorch version still supports it.
    """

    def __init__(self, amount, index):
        """Store the per-row coefficients and kernel indices on the instance."""
        super(EdgewiseSplineWeightingCPU, self).__init__()
        self.amount = amount
        self.index = index

    def forward(self, input, weight):
        """Return the [|E| x M_out] weighted sum described on the class.

        Both arguments are saved so that ``backward`` can read them again.
        """
        self.save_for_backward(input, weight)

        # Unpacking requires ``weight`` to be exactly three-dimensional.
        _, in_channels, out_channels = weight.size()
        num_slots = self.amount.size(1)

        result = input.new(input.size(0), out_channels).zero_()

        for slot in range(num_slots):
            coeff = self.amount[:, slot]  # [|E|]
            kernel_idx = self.index[:, slot]  # [|E|]

            for channel in range(in_channels):
                # Pick, for every row, the kernel slice its index points to.
                gathered = weight[:, channel][kernel_idx]  # [|E| x M_out]
                scale = input[:, channel] * coeff  # [|E|]

                # A trailing singleton dimension lets the per-row factor
                # broadcast across the output channels.
                result += scale.unsqueeze(1) * gathered  # [|E| x M_out]

        return result

    def backward(self, grad_output):
        """Return ``(grad_input, grad_weight)`` for the saved forward inputs.

        ``grad_output`` is [|E| x M_out]; the results have the shapes of
        ``input`` ([|E| x M_in]) and ``weight`` ([K x M_in x M_out]).
        """
        input, weight = self.saved_tensors

        num_kernels, in_channels, out_channels = weight.size()
        num_slots = self.amount.size(1)
        num_rows = input.size(0)

        grad_input = grad_output.new(num_rows, in_channels).zero_()
        grad_weight = grad_output.new(
            num_kernels, in_channels, out_channels).zero_()

        for slot in range(num_slots):
            coeff = self.amount[:, slot]  # [|E|]
            kernel_idx = self.index[:, slot]  # [|E|]
            # Repeat every row's kernel index across the output channels so
            # it can serve as the index tensor of ``scatter_add_``.
            scatter_idx = kernel_idx.unsqueeze(1).expand(
                kernel_idx.size(0), out_channels)  # [|E| x M_out]

            for channel in range(in_channels):
                gathered = weight[:, channel][kernel_idx]  # [|E| x M_out]

                # d(output)/d(input[:, channel]): reduce over output channels.
                grad_input[:, channel] += coeff * torch.sum(
                    grad_output * gathered, dim=1)  # [|E|]

                # d(output)/d(weight[:, channel]): rows that share a kernel
                # index accumulate into the same slice of ``grad_weight``.
                scale = input[:, channel] * coeff  # [|E|]
                contribution = scale.unsqueeze(1) * grad_output  # [|E| x M_out]
                grad_weight[:, channel, :].scatter_add_(
                    0, scatter_idx, contribution)

        return grad_input, grad_weight