# spline_conv.py
import torch
from torch.autograd import Variable
3
import time
from .spline import spline

from .edgewise_spline_weighting import edgewise_spline_weighting


def spline_conv(
        adj,  # Tensor
        input,  # Variable
        weight,  # Variable
        kernel_size,
        is_open_spline,
        K,
        forward_kernel,
        backward_kernel,
        basis_kernel,
        degree=1,
        bias=None):
    """Apply a spline-based convolution over the edges stored in ``adj``.

    The steps are: gather the input features of every edge's end vertex,
    weight them edge-wise with the spline basis computed from the edge
    values, sum the results per start vertex, add the root-node term,
    divide by the (clamped) number of edges per vertex and finally add the
    optional bias.  CPU timings of the individual stages are printed to
    stdout on every call.

    Args:
        adj: Sparse adjacency tensor; its ``_values()`` and ``_indices()``
            are read.  ``_indices()`` must unpack into ``(row, col)``.
        input: Node features.  A 1-dimensional input is treated as a single
            feature column (``unsqueeze(1)``).
        weight: Weights; ``weight[:-1]`` is passed to the edge-wise spline
            weighting and ``weight[-1]`` is the root-node weight matrix.
        kernel_size: Forwarded unchanged to ``spline``.
        is_open_spline: Forwarded unchanged to ``spline``.
        K: Forwarded unchanged to ``spline``.
        forward_kernel: Forwarded to ``edgewise_spline_weighting``.
        backward_kernel: Forwarded to ``edgewise_spline_weighting``.
        basis_kernel: Forwarded unchanged to ``spline``.
        degree: Spline degree forwarded to ``spline`` (default ``1``).
        bias: Optional value added to the result in place; skipped if
            ``None``.

    Returns:
        The convolved node features with ``adj.size(1)`` rows.
    """
    # NOTE: ``time.process_time`` reports CPU time of the current process.
    # NOTE(review): if the supplied kernels run asynchronously on a device,
    # the per-stage figures may not reflect device time -- confirm.
    t_forward = time.process_time()

    if input.dim() == 1:
        input = input.unsqueeze(1)

    values = adj._values()
    row, col = adj._indices()

    # Get features for every end vertex with shape [|E| x M_in].
    t_gather = time.process_time()
    output = input[col]
    t_gather = time.process_time() - t_gather

    # Convert to [|E| x M_in] feature matrix and calculate [|E| x M_out].
    t_basis = time.process_time()
    amount, index = spline(values, kernel_size, is_open_spline, K, degree,
                           basis_kernel)
    t_basis = time.process_time() - t_basis

    t_conv = time.process_time()
    # The last weight slice is reserved for the root node (used below).
    output = edgewise_spline_weighting(output, weight[:-1], amount, index,
                                       forward_kernel, backward_kernel)
    t_conv = time.process_time() - t_conv
    print('t_gather', t_gather, 'time_basis:', t_basis, 'time_conv:', t_conv)

    # Convolution via `scatter_add`. Converts [|E| x M_out] feature matrix to
    # [n x M_out] feature matrix.
    t_scatter_add = time.process_time()
    zero = output.data.new(adj.size(1), output.size(1)).fill_(0.0)
    # Wrap the accumulator only when `output` itself is not a plain tensor.
    zero = Variable(zero) if not torch.is_tensor(output) else zero
    r = row.view(-1, 1).expand(row.size(0), output.size(1))
    output = zero.scatter_add_(0, Variable(r), output)
    t_scatter_add = time.process_time() - t_scatter_add

    # Weighten root node features by multiplying with root weight.
    t_root_weight = time.process_time()
    output += torch.mm(input, weight[-1])
    t_root_weight = time.process_time() - t_root_weight

    # Normalize output by degree.
    t_normalize = time.process_time()
    ones = values.new(values.size(0)).fill_(1)
    zero = values.new(output.size(0)).fill_(0)
    # Count edges per start vertex.  Kept in its own name so the spline
    # `degree` parameter is not shadowed.
    node_degree = zero.scatter_add_(0, row, ones)
    # Clamp to at least 1 so vertices without edges do not divide by zero.
    node_degree = torch.clamp(node_degree, min=1)
    output = output / Variable(node_degree.view(-1, 1))
    t_normalize = time.process_time() - t_normalize

    print('t_scatter_add:', t_scatter_add, 't_root_weight:', t_root_weight,
          't_normalize:', t_normalize)

    if bias is not None:
        output += bias

    t_forward = time.process_time() - t_forward
    print('t_forward', t_forward)
    return output