# appnp.py
"""
APPNP implementation in DGL.
References
----------
Paper: https://arxiv.org/abs/1810.05997
Author's code: https://github.com/klicperajo/ppnp
"""
import torch.nn as nn
9
from dgl.nn.pytorch.conv import APPNPConv
10
11


Aymen Waheb's avatar
Aymen Waheb committed
12
13
14
15
16
17
18
class APPNP(nn.Module):
    """APPNP model: an MLP prediction step followed by a propagation step.

    The forward pass first maps the input features through a stack of
    ``nn.Linear`` layers (the "prediction" step) and then hands the result,
    together with the stored graph, to ``APPNPConv`` (the "propagation"
    step).

    Parameters
    ----------
    g
        Graph passed unchanged to ``APPNPConv`` on every forward call
        (presumably a ``DGLGraph`` -- confirm against the caller).
    in_feats : int
        Size of the last dimension of the input features.
    hiddens : sequence of int
        Output sizes of the hidden linear layers. Must contain at least one
        entry; an empty sequence raises ``IndexError``.
    n_classes : int
        Output size of the final linear layer.
    activation : callable
        Applied to the output of every linear layer except the last.
    feat_drop : float
        Dropout probability for features. A falsy value (``0``, ``None``)
        disables feature dropout.
    edge_drop
        Forwarded as the third positional argument of ``APPNPConv``.
    alpha
        Forwarded as the second positional argument of ``APPNPConv``.
    k
        Forwarded as the first positional argument of ``APPNPConv``.
    """

    def __init__(self,
                 g,
                 in_feats,
                 hiddens,
                 n_classes,
                 activation,
                 feat_drop,
                 edge_drop,
                 alpha,
                 k):
        super().__init__()
        self.g = g
        self.layers = nn.ModuleList()
        # input layer
        self.layers.append(nn.Linear(in_feats, hiddens[0]))
        # hidden layers: chain each hidden size to the next one
        for fan_in, fan_out in zip(hiddens, hiddens[1:]):
            self.layers.append(nn.Linear(fan_in, fan_out))
        # output layer
        self.layers.append(nn.Linear(hiddens[-1], n_classes))
        self.activation = activation
        # nn.Identity instead of a lambda: a lambda stored on the module
        # makes the whole model unpicklable (torch.save(model), deepcopy).
        # Identity has no parameters, so state_dict keys are unaffected.
        self.feat_drop = nn.Dropout(feat_drop) if feat_drop else nn.Identity()
        self.propagate = APPNPConv(k, alpha, edge_drop)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the parameters of every linear layer."""
        for layer in self.layers:
            layer.reset_parameters()

    def forward(self, features):
        """Compute the model output for ``features``.

        Parameters
        ----------
        features : torch.Tensor
            Input features whose last dimension is ``in_feats``
            (presumably one row per node of ``self.g`` -- confirm).

        Returns
        -------
        The result of ``APPNPConv`` applied to the stored graph and the
        MLP output, whose last dimension is ``n_classes`` before
        propagation.
        """
        # prediction step: dropout on the raw input, activation after every
        # layer but the last, and dropout again right before the last layer
        h = self.feat_drop(features)
        for layer in self.layers[:-1]:
            h = self.activation(layer(h))
        h = self.layers[-1](self.feat_drop(h))
        # propagation step
        return self.propagate(self.g, h)