gat.py
from typing import Optional

import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn import Linear, ModuleList
from torch_sparse import SparseTensor
from torch_geometric.nn import GATConv

from .base import HistoryGNN


class GAT(HistoryGNN):
    """Graph attention network whose hidden layers push/pull per-node
    historical embeddings (see the ``HistoryGNN`` base class), enabling
    mini-batch training and layer-wise inference."""

    def __init__(self, num_nodes: int, in_channels: int, hidden_channels: int,
                 hidden_heads: int, out_channels: int, out_heads: int,
                 num_layers: int, residual: bool = False, dropout: float = 0.0,
                 device=None, dtype=None):
        super().__init__(num_nodes, hidden_channels * hidden_heads,
                         num_layers, device, dtype)

        self.in_channels = in_channels
        self.hidden_heads = hidden_heads
        self.out_channels = out_channels
        self.out_heads = out_heads
        self.residual = residual
        self.dropout = dropout

        self.convs = ModuleList()
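        # All but the last layer use multi-head attention with concatenated
        # heads, so each hidden layer outputs hidden_channels * hidden_heads
        # features.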
        for i in range(num_layers - 1):
            in_dim = in_channels if i == 0 else hidden_channels * hidden_heads
            conv = GATConv(in_dim, hidden_channels, hidden_heads, concat=True,
                           dropout=dropout, add_self_loops=False)
            self.convs.append(conv)

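        # Final layer averages its attention heads (concat=False) so the
        # output has exactly out_channels features.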
        conv = GATConv(hidden_channels * hidden_heads, out_channels, out_heads,
                       concat=False, dropout=dropout, add_self_loops=False)
        self.convs.append(conv)

        self.lins = ModuleList()
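        # Projections for the residual connections: the first maps the raw
        # input to the hidden dimension, the second maps the hidden
        # representation to the output dimension.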
        if residual:
            self.lins.append(
                Linear(in_channels, hidden_channels * hidden_heads))
            self.lins.append(
                Linear(hidden_channels * hidden_heads, out_channels))

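        # Parameter groups, e.g., to apply weight decay only to `reg_modules`
        # in the optimizer.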
        self.reg_modules = ModuleList([self.convs, self.lins])
        self.nonreg_modules = ModuleList()

    def reset_parameters(self):
        super().reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        for lin in self.lins:
            lin.reset_parameters()

    def forward(self, x: Tensor, adj_t: SparseTensor,
                batch_size: Optional[int] = None,
                n_id: Optional[Tensor] = None) -> Tensor:

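        # Hidden layers: dropout -> GATConv -> optional residual -> ELU,
        # followed by a synchronization with this layer's history.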
        for conv, history in zip(self.convs[:-1], self.histories):
            h = F.dropout(x, p=self.dropout, training=self.training)
            h = conv(h, adj_t)
            if self.residual:
                x = F.dropout(x, p=self.dropout, training=self.training)
                h = h + x if h.size(-1) == x.size(-1) else h + self.lins[0](x)
            x = F.elu(h)
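            # Push the freshly computed in-batch embeddings into the history
            # and pull stored embeddings for the remaining nodes (behavior
            # defined in HistoryGNN).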
            x = self.push_and_pull(history, x, batch_size, n_id)

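        # Final layer (no history); restrict the output to the in-batch nodes.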
        h = F.dropout(x, p=self.dropout, training=self.training)
        h = self.convs[-1](h, adj_t)
        if self.residual:
            x = F.dropout(x, p=self.dropout, training=self.training)
            h = h + self.lins[1](x)
        if batch_size is not None:
            h = h[:batch_size]
        return h

    @torch.no_grad()
    def mini_inference(self, x: Tensor, loader) -> Tensor:
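        """Layer-wise full-graph inference: evaluate one layer over all
        mini-batches from `loader`, caching the results in the history
        buffers before moving on to the next layer."""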
        for conv, history in zip(self.convs[:-1], self.histories):
            for info in loader:
                info = info.to(self.device)
                batch_size, n_id, adj_t, e_id = info

                r = x[n_id]
                h = conv(r, adj_t)
                if self.residual:
                    if h.size(-1) == r.size(-1):
                        h = h + r
                    else:
                        h = h + self.lins[0](r)
                h = F.elu(h)
                history.push_(h[:batch_size], n_id[:batch_size])

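            # Every node has been processed for this layer; read back the
            # full embedding matrix from the history.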
            x = history.pull()

        out = x.new_empty(self.num_nodes, self.out_channels)
        for info in loader:
            info = info.to(self.device)
            batch_size, n_id, adj_t, e_id = info
            r = x[n_id]
            h = self.convs[-1](r, adj_t)[:batch_size]
            if self.residual:
                # Slice `r` to the in-batch nodes so the residual matches the
                # shape of `h`.
                h = h + self.lins[1](r[:batch_size])
            out[n_id[:batch_size]] = h

        return out
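

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). It
# assumes that `HistoryGNN.push_and_pull` is a plain pass-through when
# `batch_size` and `n_id` are left as None, i.e., a full-graph forward pass;
# the graph and all hyperparameter values below are arbitrary.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    num_nodes, in_channels, num_classes = 100, 16, 7

    # Random graph with 500 directed edges, stored as a sparse adjacency.
    row = torch.randint(0, num_nodes, (500,))
    col = torch.randint(0, num_nodes, (500,))
    adj_t = SparseTensor(row=row, col=col,
                         sparse_sizes=(num_nodes, num_nodes))

    model = GAT(num_nodes=num_nodes, in_channels=in_channels,
                hidden_channels=8, hidden_heads=4,
                out_channels=num_classes, out_heads=1,
                num_layers=3, residual=True, dropout=0.5)
    model.eval()

    x = torch.randn(num_nodes, in_channels)
    out = model(x, adj_t)
    print(out.shape)  # expected: torch.Size([100, 7])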