# spline_conv.py
import torch
from torch.autograd import Variable
import time

from .spline import spline

from .edgewise_spline_weighting import edgewise_spline_weighting


def spline_conv(
        adj,  # Tensor
        input,  # Variable
        weight,  # Variable
        kernel_size,
        is_open_spline,
        K,
        forward_kernel,
        backward_kernel,
        basis_kernel,
        degree=1,
        bias=None, ):
    """Spline-based graph convolution over a sparse adjacency matrix.

    For every edge ``(row, col)`` of ``adj``:

    - the features of the end vertex ``col`` are weighted by a spline kernel
      evaluated at the edge value;
    - the weighted messages are summed into node ``row``;
    - the root node's own features, multiplied by the last weight slice,
      are added;
    - the result is divided by the number of edges per ``row`` node
      (at least 1).

    Args:
        adj: Sparse adjacency tensor (read via ``_indices()``/``_values()``).
            Its values are passed to ``spline`` as kernel inputs (presumably
            per-edge pseudo-coordinates -- confirm against ``spline``).
        input: Node feature matrix ``[n x M_in]``. A 1-D input is treated
            as a single feature column.
        weight: Weight stack. ``weight[:-1]`` goes to the edgewise spline
            weighting; ``weight[-1]`` is the ``[M_in x M_out]`` root weight.
        kernel_size, is_open_spline, K: Spline configuration forwarded to
            ``spline``.
        forward_kernel, backward_kernel: Kernels forwarded to
            ``edgewise_spline_weighting``.
        basis_kernel: Kernel forwarded to ``spline``.
        degree: Spline degree forwarded to ``spline`` (default 1).
        bias: Optional bias added to the normalized output.

    Returns:
        Output feature matrix ``[n x M_out]``, of the same kind (``Variable``
        or plain tensor) as the result of ``edgewise_spline_weighting``.
    """
    if input.dim() == 1:
        input = input.unsqueeze(1)

    values = adj._values()
    row, col = adj._indices()

    # Get features for every end vertex with shape [|E| x M_in].
    output = input[col]

    # Evaluate the spline for every edge value; `amount` and `index` are
    # consumed by the edgewise weighting below.
    amount, index = spline(values, kernel_size, is_open_spline, K, degree,
                           basis_kernel)

    # Convert the [|E| x M_in] feature matrix to [|E| x M_out] using every
    # weight slice except the last (root) one.
    output = edgewise_spline_weighting(output, weight[:-1], amount, index,
                                       forward_kernel, backward_kernel)

    # Helper tensors built below must match the kind of `output`. On the
    # legacy Variable API, mixing a plain tensor and a Variable in
    # `scatter_add_` or `/` fails, so wrap only when `output` is a Variable.
    as_variable = not torch.is_tensor(output)

    def _like_output(tensor):
        return Variable(tensor) if as_variable else tensor

    # Convolution via `scatter_add`: sum the [|E| x M_out] edge messages into
    # their target node `row`, giving an [n x M_out] feature matrix. `row`
    # indexes dimension 0 of `adj`, so that dimension sizes the result.
    num_nodes = adj.size(0)
    zero = output.data.new(num_nodes, output.size(1)).fill_(0.0)
    row_index = row.view(-1, 1).expand(row.size(0), output.size(1))
    output = _like_output(zero).scatter_add_(
        0, _like_output(row_index), output)

    # Weight root node features by multiplying with the root weight.
    output += torch.mm(input, weight[-1])

    # Normalize output by node degree, i.e. the number of edges whose `row`
    # is that node. Clamp to 1 so nodes without edges are not divided by 0.
    # (Named `node_degree` so it does not shadow the spline `degree` arg.)
    ones = values.new(values.size(0)).fill_(1)
    node_degree = values.new(num_nodes).fill_(0).scatter_add_(0, row, ones)
    node_degree = torch.clamp(node_degree, min=1)
    output = output / _like_output(node_degree.view(-1, 1))

    if bias is not None:
        output += bias

    return output