from typing import Optional

import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn import ModuleList, Linear
from torch_sparse import SparseTensor

from torch_geometric_autoscale.models import ScalableGNN


class APPNP(ScalableGNN):
    """Scalable APPNP: predict-then-propagate with cached histories.

    A two-layer MLP first maps node features to ``out_channels``
    predictions; those predictions are then smoothed by ``num_layers``
    rounds of personalized-PageRank propagation,
    ``x <- (1 - alpha) * A @ x + alpha * x_0``.
    Intermediate embeddings are stored in the per-layer histories managed
    by :class:`ScalableGNN`, so mini-batch propagation can pull (possibly
    stale) embeddings for out-of-batch neighbors.
    """

    def __init__(self, num_nodes: int, in_channels, hidden_channels: int,
                 out_channels: int, num_layers: int, alpha: float,
                 dropout: float = 0.0, pool_size: Optional[int] = None,
                 buffer_size: Optional[int] = None, device=None):
        # Histories hold `out_channels`-sized embeddings, one history per
        # propagation step.
        super().__init__(num_nodes, out_channels, num_layers, pool_size,
                         buffer_size, device)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.alpha = alpha  # teleport (restart) probability
        self.dropout = dropout

        # Two-layer MLP applied before any propagation.
        self.lins = ModuleList([
            Linear(in_channels, hidden_channels),
            Linear(hidden_channels, out_channels),
        ])

        # Split used by the optimizer: only the first linear layer is
        # subject to weight-decay regularization.
        self.reg_modules = self.lins[:1]
        self.nonreg_modules = self.lins[1:]

    def reset_parameters(self):
        """Re-initialize all learnable weights (and clear histories)."""
        super().reset_parameters()
        for lin in self.lins:
            lin.reset_parameters()

    def forward(self, x: Tensor, adj_t: SparseTensor, *args) -> Tensor:
        """Mini-batch forward pass with history-based propagation."""
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lins[0](x).relu()
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lins[1](x)
        # Teleport term: MLP predictions restricted to the in-batch
        # (target) nodes, i.e. the first `adj_t.size(0)` rows.
        x_0 = x[:adj_t.size(0)]

        for history in self.histories:
            x = (1 - self.alpha) * (adj_t @ x) + self.alpha * x_0
            # Push in-batch embeddings into the history and pull stored
            # embeddings for out-of-batch neighbors.
            x = self.push_and_pull(history, x, *args)

        # Last propagation step has no history attached to it.
        return (1 - self.alpha) * (adj_t @ x) + self.alpha * x_0

    @torch.no_grad()
    def forward_layer(self, layer, x, adj_t, state):
        """Evaluate a single propagation step (layer-wise inference).

        On the first layer the MLP is applied and the teleport vector is
        stashed in ``state`` for reuse by all subsequent layers.
        """
        if layer == 0:
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = self.lins[0](x).relu()
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = self.lins[1](x)
            state['x_0'] = x[:adj_t.size(0)]

        return (1 - self.alpha) * (adj_t @ x) + self.alpha * state['x_0']