from typing import Optional

import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn import Linear, ModuleList
from torch_sparse import SparseTensor
from torch_geometric.nn import GATConv

from torch_geometric_autoscale.models import ScalableGNN


class GAT(ScalableGNN):
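    r"""A multi-layer :class:`~torch_geometric.nn.GATConv` model built on top
    of :class:`ScalableGNN`, i.e. trained with historical embeddings in the
    style of :obj:`torch_geometric_autoscale` (GNNAutoScale).

    Hidden layers use :obj:`hidden_heads` attention heads with concatenated
    outputs of size :obj:`hidden_channels * hidden_heads`; the output layer
    uses :obj:`out_heads` heads whose outputs are averaged. If :obj:`residual`
    is set, skip connections are added, with linear projections wherever the
    dimensions differ.
    """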
    def __init__(self, num_nodes: int, in_channels: int, hidden_channels: int,
                 hidden_heads: int, out_channels: int, out_heads: int,
                 num_layers: int, residual: bool = False, dropout: float = 0.0,
                 pool_size: Optional[int] = None,
                 buffer_size: Optional[int] = None, device=None):
        super(GAT, self).__init__(num_nodes, hidden_channels * hidden_heads,
                                  num_layers, pool_size, buffer_size, device)

        self.in_channels = in_channels
        self.hidden_heads = hidden_heads
        self.out_channels = out_channels
        self.out_heads = out_heads
        self.residual = residual
        self.dropout = dropout

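        # `num_layers - 1` hidden GAT layers with `hidden_heads` concatenated
        # attention heads each, followed by a single output layer below.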
        self.convs = ModuleList()
        for i in range(num_layers - 1):
            in_dim = in_channels if i == 0 else hidden_channels * hidden_heads
            conv = GATConv(in_dim, hidden_channels, hidden_heads, concat=True,
                           dropout=dropout, add_self_loops=False)
            self.convs.append(conv)

        conv = GATConv(hidden_channels * hidden_heads, out_channels, out_heads,
                       concat=False, dropout=dropout, add_self_loops=False)
        self.convs.append(conv)

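        # Linear projections for the residual connections at the layers whose
        # input and output dimensions differ (the first and the last layer).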
        self.lins = ModuleList()
        if residual:
            self.lins.append(
                Linear(in_channels, hidden_channels * hidden_heads))
            self.lins.append(
                Linear(hidden_channels * hidden_heads, out_channels))

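        # Parameter groups exposed to the training code; in the reference GAS
        # scripts these are, to our understanding, used to apply weight decay
        # to `reg_modules` only (an assumption about the surrounding code).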
        self.reg_modules = ModuleList([self.convs, self.lins])
        self.nonreg_modules = ModuleList()

    def reset_parameters(self):
        super(GAT, self).reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        for lin in self.lins:
            lin.reset_parameters()

    def forward(self, x: Tensor, adj_t: SparseTensor,
                batch_size: Optional[int] = None,
                n_id: Optional[Tensor] = None, offset: Optional[Tensor] = None,
                count: Optional[Tensor] = None) -> Tensor:

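        # One GAT layer per history: dropout, a bipartite conv in which the
        # first `adj_t.size(0)` rows of `h` act as target nodes, an optional
        # residual connection, ELU, and `push_and_pull`, which writes the
        # fresh in-batch embeddings to this layer's history and reads the
        # (possibly stale) embeddings of out-of-batch neighbors.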
        for conv, history in zip(self.convs[:-1], self.histories):
            h = F.dropout(x, p=self.dropout, training=self.training)
            h = conv((h, h[:adj_t.size(0)]), adj_t)
            if self.residual:
                x = F.dropout(x, p=self.dropout, training=self.training)
                h += x if h.size(-1) == x.size(-1) else self.lins[0](x)
            x = F.elu(h)
            x = self.push_and_pull(history, x, batch_size, n_id, offset, count)

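        # Output layer: `concat=False` averages the `out_heads` attention
        # heads; its result is returned directly and is not stored in a
        # history.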
        h = F.dropout(x, p=self.dropout, training=self.training)
        h = self.convs[-1]((h, h[:adj_t.size(0)]), adj_t)
        if self.residual:
            x = F.dropout(x, p=self.dropout, training=self.training)
            h += self.lins[1](x)
        return h

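    # Single-layer forward pass, used by ScalableGNN for layer-wise mini-batch
    # inference; `state` is unused here but is part of the expected signature.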
    @torch.no_grad()
    def forward_layer(self, layer, x, adj_t, state):
        h = F.dropout(x, p=self.dropout, training=self.training)
        h = self.convs[layer]((h, h[:adj_t.size(0)]), adj_t)

        if self.residual:
            # Only apply the residual projections when they exist (`self.lins`
            # is empty for `residual=False`), mirroring `forward`: project
            # input -> hidden at the first layer and hidden -> output at the
            # last layer; slice `x` in case it holds more (neighbor) rows
            # than `h`.
            x = F.dropout(x, p=self.dropout, training=self.training)
            if layer == 0:
                x = self.lins[0](x)
            elif layer == self.num_layers - 1:
                x = self.lins[1](x)
            h += x[:h.size(0)]

        if layer < self.num_layers - 1:
            h = F.elu(h)

        return h
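

# A minimal, illustrative smoke test (not part of the original module): it
# builds a small random graph and runs a full-batch forward pass through the
# model, following the `forward` signature above. All hyperparameters and
# sizes are arbitrary; `torch_sparse` and `torch_geometric_autoscale` are
# assumed to be installed.
if __name__ == "__main__":
    num_nodes, num_feats, num_classes = 100, 16, 7
    row = torch.randint(0, num_nodes, (500,))
    col = torch.randint(0, num_nodes, (500,))
    adj_t = SparseTensor(row=row, col=col,
                         sparse_sizes=(num_nodes, num_nodes))
    x = torch.randn(num_nodes, num_feats)

    model = GAT(num_nodes=num_nodes, in_channels=num_feats,
                hidden_channels=8, hidden_heads=4, out_channels=num_classes,
                out_heads=1, num_layers=3, residual=True, dropout=0.5)
    out = model(x, adj_t)  # full-batch: every node counts as "in-batch"
    print(out.shape)  # expected: torch.Size([100, 7])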