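"""GAT model for `torch_geometric_autoscale` (GNNAutoScale): hidden-layer
outputs are cached in per-layer histories so that mini-batch training can
approximate full-graph message passing."""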
from typing import Optional

import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn import ModuleList
from torch_sparse import SparseTensor
from torch_geometric.nn import GATConv

from torch_geometric_autoscale.models import ScalableGNN


class GAT(ScalableGNN):
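    """Multi-layer GAT built on :class:`ScalableGNN`: each hidden layer
    writes its output to a history embedding table, from which cached
    out-of-batch embeddings are served during mini-batch training."""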
    def __init__(self, num_nodes: int, in_channels: int, hidden_channels: int,
                 hidden_heads: int, out_channels: int, out_heads: int,
                 num_layers: int, dropout: float = 0.0,
                 pool_size: Optional[int] = None,
                 buffer_size: Optional[int] = None, device=None):
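        # The per-layer histories store hidden-layer outputs, whose width is
        # hidden_channels * hidden_heads due to head concatenation.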
        super().__init__(num_nodes, hidden_channels * hidden_heads, num_layers,
                         pool_size, buffer_size, device)

        self.in_channels = in_channels
        self.hidden_heads = hidden_heads
        self.out_channels = out_channels
        self.out_heads = out_heads
        self.dropout = dropout

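        # All but the last layer concatenate their attention heads.
        # Self-loops cannot be added on the fly because the mini-batch
        # adjacency is generally rectangular (all loaded nodes -> in-batch
        # nodes), hence add_self_loops=False.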
        self.convs = ModuleList()
        for i in range(num_layers - 1):
            in_dim = in_channels if i == 0 else hidden_channels * hidden_heads
            conv = GATConv(in_dim, hidden_channels, hidden_heads, concat=True,
                           dropout=dropout, add_self_loops=False)
            self.convs.append(conv)

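        # The output layer averages its attention heads (concat=False)
        # instead of concatenating them.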
        conv = GATConv(hidden_channels * hidden_heads, out_channels, out_heads,
                       concat=False, dropout=dropout, add_self_loops=False)
        self.convs.append(conv)

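        # Parameter groups read by the GAS training loop: weight decay is
        # applied to reg_modules only.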
        self.reg_modules = self.convs
        self.nonreg_modules = ModuleList()

    def reset_parameters(self):
        super().reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()

    def forward(self, x: Tensor, adj_t: SparseTensor, *args) -> Tensor:
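        """Forward pass over all loaded nodes; `*args` forwards optional
        mini-batch metadata (e.g. batch size and node ids) to
        :meth:`push_and_pull`."""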
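        # Hidden layers: one history per layer (the last layer has none).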
        for conv, history in zip(self.convs[:-1], self.histories):
            x = F.dropout(x, p=self.dropout, training=self.training)
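            # Bipartite propagation: all loaded nodes act as message sources,
            # the first adj_t.size(0) rows are the in-batch target nodes.
            # push_and_pull then writes the fresh in-batch embeddings to the
            # history and pulls cached ones for out-of-batch nodes.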
            x = conv((x, x[:adj_t.size(0)]), adj_t)
            x = F.elu(x)
            x = self.push_and_pull(history, x, *args)

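        # Output layer: no history and no non-linearity; raw scores are
        # returned.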
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[-1]((x, x[:adj_t.size(0)]), adj_t)
        return x

    @torch.no_grad()
    def forward_layer(self, layer, x, adj_t, state):
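        """Evaluate a single layer, as used by :class:`ScalableGNN`'s
        layer-wise inference; `state` is part of the interface and unused
        by this model."""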
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[layer]((x, x[:adj_t.size(0)]), adj_t)

        if layer < self.num_layers - 1:
            x = F.elu(x)

        return x
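

# A minimal usage sketch (not part of the original module), shown full-batch
# on a random toy graph; the hyperparameters below are made up. Mini-batch
# training would additionally pass batch metadata through *args.
if __name__ == "__main__":
    num_nodes = 100
    row = torch.randint(0, num_nodes, (500,))
    col = torch.randint(0, num_nodes, (500,))
    adj_t = SparseTensor(row=row, col=col,
                         sparse_sizes=(num_nodes, num_nodes))

    model = GAT(num_nodes=num_nodes, in_channels=16, hidden_channels=8,
                hidden_heads=4, out_channels=7, out_heads=1, num_layers=2)
    model.reset_parameters()

    x = torch.randn(num_nodes, 16)
    out = model(x, adj_t)  # full-batch: histories are simply refreshed
    print(out.shape)       # -> torch.Size([100, 7])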