spline_conv.py 2.38 KB
Newer Older
rusty1s's avatar
rename  
rusty1s committed
1
2
import torch
from torch.autograd import Variable
3
import time
rusty1s's avatar
rename  
rusty1s committed
4

5
from .spline_conv_gpu import SplineConvGPU
rusty1s's avatar
rename  
rusty1s committed
6
7

def spline_conv(
        adj,  # Pytorch Tensor (!bp_to_adj) or Pytorch Variable (bp_to_adj)
        input,  # Pytorch Variable
        weight,  # Pytorch Variable
        kernel_size,  # Rest tensors or python variables
        is_open_spline,
        K,
        weighting_kernel,
        weighting_backward_kernel,
        basis_kernel,
        basis_backward_kernel=None,
        degree=1,
        bias=None):
    """Spline-based graph convolution over a sparse adjacency (CUDA only).

    For every edge ``(row, col)``:

    1. The feature vector of vertex ``col`` is gathered.
    2. It is transformed by ``SplineConvGPU`` using ``weight[:-1]``.
    3. The result is summed into vertex ``row``.

    After that, the root term ``input @ weight[-1]`` is added. The sum is
    divided by the number of edges per ``row`` vertex (clamped to at least 1).
    Finally ``bias`` is added when given.

    Args:
        adj: dict read via the keys ``'indices'``, ``'values'`` and ``'size'``.
            - ``'indices'`` is unpacked into ``row, col``.
            - ``'values'`` holds the per-edge values handed to
              ``SplineConvGPU``.
            - ``adj['size'][1]`` is used as the number of output rows.
            If ``'values'`` is a plain tensor it is passed to the op's
            constructor. Otherwise it is passed as a forward input, with
            ``bp_to_adj=True`` (presumably so gradients reach it; TODO
            confirm in SplineConvGPU).
        input: vertex feature Variable; a 1-D input is treated as a single
            feature channel, ``[n] -> [n x 1]``.
        weight: Variable whose last entry ``weight[-1]`` is the root weight
            (used with ``torch.mm``). The remaining entries ``weight[:-1]``
            go to ``SplineConvGPU``.
        kernel_size, is_open_spline, K, weighting_kernel,
        weighting_backward_kernel, basis_kernel, basis_backward_kernel,
        degree: forwarded unchanged to ``SplineConvGPU``. ``degree`` is the
            spline degree, not the graph degree.
        bias: optional term added to the final output.

    Returns:
        The convolved vertex features with ``adj['size'][1]`` rows.

    Raises:
        NotImplementedError: if the input lives on the CPU.
    """
    if input.dim() == 1:
        # Treat a 1-D feature vector as a single input channel.
        input = input.unsqueeze(1)

    values = adj['values']
    row, col = adj['indices']

    # Get features for every end vertex with shape [|E| x M_in].
    output = input[col]

    if not output.is_cuda:
        # Only the GPU kernels exist; there is no CPU fallback.
        raise NotImplementedError(
            'spline_conv is only implemented for CUDA tensors')

    # A non-tensor (Variable) `values` selects the mode where the edge values
    # are fed to the op as a forward input instead of a constructor argument.
    bp_to_adj = not torch.is_tensor(values)

    # Convert to [|E| x M_in] feature matrix and calculate [|E| x M_out].
    if bp_to_adj:
        output = SplineConvGPU(kernel_size, is_open_spline, K, degree,
                               basis_kernel, basis_backward_kernel,
                               weighting_kernel, weighting_backward_kernel,
                               bp_to_adj)(output, weight[:-1], values)
    else:
        output = SplineConvGPU(kernel_size, is_open_spline, K, degree,
                               basis_kernel, basis_backward_kernel,
                               weighting_kernel, weighting_backward_kernel,
                               bp_to_adj, values)(output, weight[:-1])

    # Convolution via `scatter_add`. Converts [|E| x M_out] feature matrix to
    # [n x M_out] feature matrix by summing each edge's result into `row`.
    zero = output.data.new(adj['size'][1], output.size(1)).fill_(0.0)
    zero = Variable(zero) if not torch.is_tensor(output) else zero
    index = row.view(-1, 1).expand(row.size(0), output.size(1))
    output = zero.scatter_add_(0, Variable(index), output)

    # Weighten root node features by multiplying with root weight.
    output += torch.mm(input, weight[-1])

    # Normalize output by the number of edges per `row` vertex. One count is
    # scattered per edge, and the counts are clamped to at least 1 so that
    # vertices without edges are not divided by zero. Named `deg` so it does
    # not overwrite the spline `degree` parameter.
    ones = output.data.new(row.size(0)).fill_(1)
    deg = output.data.new(output.size(0)).fill_(0).scatter_add_(0, row, ones)
    deg = torch.clamp(deg, min=1)
    output = output / Variable(deg.view(-1, 1))

    if bias is not None:
        output += bias

    return output