"""
APPNP implementation in DGL.
References
----------
Paper: https://arxiv.org/abs/1810.05997
Author's code: https://github.com/klicperajo/ppnp
"""
import torch
import torch.nn as nn
import dgl.function as fn
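# A sketch of the update rule from the paper cited above, written out for reference:
#     prediction:   H^(0)   = f_theta(X)                                    -> class APPNP (an MLP)
#     propagation:  H^(l+1) = (1 - alpha) * A_hat @ H^(l) + alpha * H^(0),  l = 0..k-1
#                                                                           -> class GraphPropagation
# where A_hat is the symmetrically normalized adjacency matrix (with self-loops
# in the paper); the per-node D^{-1/2} factor is expected in g.ndata['norm'].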


class GraphPropagation(nn.Module):
    """Propagation step of APPNP: k rounds of personalized-PageRank-style
    smoothing with teleport (restart) probability alpha on graph g."""

    def __init__(self,
                 g,
                 edge_drop,
                 alpha,
                 k):
        super(GraphPropagation, self).__init__()
        self.g = g
        self.alpha = alpha
        self.k = k
        if edge_drop:
            self.edge_drop = nn.Dropout(edge_drop)
        else:
            # falsy sentinel; the forward pass then skips the dropout branch
            self.edge_drop = 0.

    def forward(self, h):
        self.cached_h = h
        for _ in range(self.k):
            # normalization by square root of src degree
            h = h * self.g.ndata['norm']
            self.g.ndata['h'] = h
            if self.edge_drop:
                # performing edge dropout
                ed = self.edge_drop(
                    torch.ones((self.g.number_of_edges(), 1), device=h.device))
                self.g.edata['d'] = ed
                self.g.update_all(fn.src_mul_edge(src='h', edge='d', out='m'),
                                  fn.sum(msg='m', out='h'))
            else:
                self.g.update_all(fn.copy_src(src='h', out='m'),
                                  fn.sum(msg='m', out='h'))
            h = self.g.ndata.pop('h')
            # normalization by square root of dst degree
            h = h * self.g.ndata['norm']
            # update h using teleport probability alpha
            h = h * (1 - self.alpha) + self.cached_h * self.alpha
        return h
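
# A minimal sketch (not part of the original example) of how the per-node
# 'norm' field used above is typically precomputed: D^{-1/2} for every node,
# multiplied in once on the source side and once on the destination side.
def _symmetric_norm(g):
    """Return the D^{-1/2} factor expected in g.ndata['norm'] (assumes the
    graph already has self-loops, matching A_hat in the paper)."""
    degs = g.in_degrees().float().clamp(min=1)
    return torch.pow(degs, -0.5).unsqueeze(1)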


class APPNP(nn.Module):
    """APPNP model: an MLP prediction step followed by GraphPropagation."""

    def __init__(self,
                 g,
                 in_feats,
                 hiddens,
                 n_classes,
                 activation,
                 feat_drop,
                 edge_drop,
                 alpha,
                 k):
        super(APPNP, self).__init__()
        self.layers = nn.ModuleList()
        # input layer
        self.layers.append(nn.Linear(in_feats, hiddens[0]))
        # hidden layers
        for i in range(1, len(hiddens)):
            self.layers.append(nn.Linear(hiddens[i - 1], hiddens[i]))
        # output layer
        self.layers.append(nn.Linear(hiddens[-1], n_classes))
        self.activation = activation
        if feat_drop:
            self.feat_drop = nn.Dropout(feat_drop)
        else:
            self.feat_drop = lambda x: x
        self.propagate = GraphPropagation(g, edge_drop, alpha, k)
        self.reset_parameters()

    def reset_parameters(self):
        for layer in self.layers:
            layer.reset_parameters()

    def forward(self, features):
        # prediction step
        h = features
        h = self.feat_drop(h)
        h = self.activation(self.layers[0](h))
        for layer in self.layers[1:-1]:
            h = self.activation(layer(h))
        h = self.layers[-1](self.feat_drop(h))
        # propagation step
        h = self.propagate(h)
        return h
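

# A minimal usage sketch under assumptions not in the original example: a random
# toy graph, random features, and a DGL version that still provides fn.copy_src
# and fn.src_mul_edge. A real training script would load a dataset instead.
if __name__ == '__main__':
    import dgl

    graph = dgl.add_self_loop(dgl.rand_graph(100, 500))
    # per-node D^{-1/2}, stored where GraphPropagation.forward expects it
    graph.ndata['norm'] = _symmetric_norm(graph)

    model = APPNP(graph, in_feats=16, hiddens=[64], n_classes=7,
                  activation=torch.relu, feat_drop=0.5, edge_drop=0.5,
                  alpha=0.1, k=10)
    logits = model(torch.randn(graph.number_of_nodes(), 16))
    print(logits.shape)  # torch.Size([100, 7])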