import time

import torch
from torch.autograd import Variable

from .spline_conv_gpu import SplineConvGPU


def spline_conv(
        adj,
        input,
        weight,
        kernel_size,
        is_open_spline,
        K,
        weighting_kernel,
        weighting_backward_kernel,
        basis_kernel,
        basis_backward_kernel=None,
        degree=1,
        bias=None,
        bp_to_adj=False):
    """Apply a spline-based graph convolution and normalize by node degree.

    The edge-wise work is delegated to ``SplineConvGPU``; this function
    gathers the per-edge inputs, scatters the per-edge results back onto the
    nodes, adds the root-node term, normalizes and optionally adds a bias.

    Args:
        adj: Mapping with the entries ``'indices'`` (unpacked into ``row``
            and ``col`` index tensors), ``'values'`` (passed to the GPU
            kernel as the third operand; must expose ``.data``) and
            ``'size'`` (``adj['size'][1]`` is used as the number of output
            rows).
        input: Node features. A 1-D input is treated as a single feature
            column.
        weight: Stacked weights. ``weight[:-1]`` is forwarded to the GPU
            kernel, ``weight[-1]`` is the 2-D root weight multiplied with
            ``input``.
        kernel_size: Forwarded unchanged to ``SplineConvGPU``.
        is_open_spline: Forwarded unchanged to ``SplineConvGPU``.
        K: Forwarded unchanged to ``SplineConvGPU``.
        weighting_kernel: Forwarded unchanged to ``SplineConvGPU``.
        weighting_backward_kernel: Forwarded unchanged to ``SplineConvGPU``.
        basis_kernel: Forwarded unchanged to ``SplineConvGPU``.
        basis_backward_kernel: Forwarded unchanged to ``SplineConvGPU``.
        degree: Forwarded unchanged to ``SplineConvGPU`` (presumably the
            spline degree -- confirm against ``spline_conv_gpu``). It is
            unrelated to the node degree used for normalization below.
        bias: Optional term added to the normalized output.
        bp_to_adj: Forwarded unchanged to ``SplineConvGPU`` (presumably
            enables backpropagation to the adjacency values -- confirm).

    Returns:
        The convolved node features, one row per node.

    Raises:
        NotImplementedError: If the gathered features do not live on a CUDA
            device; there is no CPU implementation.
    """
    if input.dim() == 1:
        input = input.unsqueeze(1)

    values = adj['values']
    row, col = adj['indices']

    # Get features for every end vertex with shape [|E| x M_in].
    output = input[col]

    if not output.is_cuda:
        raise NotImplementedError(
            'spline_conv has no CPU implementation; the inputs must be on a '
            'CUDA device')

    # Convert the [|E| x M_in] feature matrix to [|E| x M_out].
    conv = SplineConvGPU(kernel_size, is_open_spline, K, degree,
                         basis_kernel, basis_backward_kernel,
                         weighting_kernel, weighting_backward_kernel,
                         bp_to_adj)
    output = conv(output, weight[:-1], values)

    # Convolution via `scatter_add`. Converts the [|E| x M_out] feature
    # matrix to an [n x M_out] feature matrix by summing over the edges that
    # share a row index.
    aggregated = output.data.new(adj['size'][1], output.size(1)).fill_(0.0)
    if not torch.is_tensor(output):
        # Legacy autograd: a raw buffer must be wrapped before it can take
        # part in an operation with a Variable.
        aggregated = Variable(aggregated)
    # `scatter_add_` needs one target row index per output element.
    scatter_index = row.view(-1, 1).expand(row.size(0), output.size(1))
    output = aggregated.scatter_add_(0, Variable(scatter_index), output)

    # Weight the root node features by multiplying with the root weight.
    output += torch.mm(input, weight[-1])

    # Normalize output by node degree (number of edges per row index). The
    # degree is clamped to at least 1 so isolated nodes do not divide by 0.
    ones = values.data.new(values.size(0)).fill_(1)
    counts = values.data.new(output.size(0)).fill_(0)
    node_degree = counts.scatter_add_(0, row, ones)
    node_degree = torch.clamp(node_degree, min=1)
    output = output / Variable(node_degree.view(-1, 1))

    if bias is not None:
        output += bias

    return output