"""
Graph Attention Networks in DGL using SPMV optimization.
References
----------
Paper: https://arxiv.org/abs/1710.10903
Author's code: https://github.com/PetarV-/GAT
Pytorch implementation: https://github.com/Diego999/pyGAT
"""

import torch
import torch.nn as nn
import dgl.function as fn
from dgl.nn import GATConv


class GAT(nn.Module):
    def __init__(self,
                 g,
                 num_layers,
                 in_dim,
                 num_hidden,
                 num_classes,
                 heads,
                 activation,
                 feat_drop,
                 attn_drop,
                 negative_slope,
                 residual):
        super(GAT, self).__init__()
        self.g = g
        self.num_layers = num_layers
        self.gat_layers = nn.ModuleList()
        self.activation = activation
        # input projection (no residual)
        self.gat_layers.append(GATConv(
            in_dim, num_hidden, heads[0],
            feat_drop, attn_drop, negative_slope, False, self.activation))
        # hidden layers
        for l in range(1, num_layers):
            # multi-head outputs are concatenated, so in_dim = num_hidden * num_heads
            self.gat_layers.append(GATConv(
                num_hidden * heads[l-1], num_hidden, heads[l],
                feat_drop, attn_drop, negative_slope, residual, self.activation))
        # output projection
        self.gat_layers.append(GATConv(
            num_hidden * heads[-2], num_classes, heads[-1],
            feat_drop, attn_drop, negative_slope, residual, None))

    def forward(self, inputs):
        h = inputs
        for l in range(self.num_layers):
            h = self.gat_layers[l](self.g, h).flatten(1)  # concatenate attention heads
        # output projection
        logits = self.gat_layers[-1](self.g, h).mean(1)  # average attention heads
        return logits
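

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original example): it
# builds a small random DGL graph with self-loops and runs one forward pass.
# The graph, feature sizes, and hyperparameters below are arbitrary
# assumptions, chosen only to show how the constructor arguments fit together
# (note that len(heads) must equal num_layers + 1).
if __name__ == "__main__":
    import dgl
    import torch.nn.functional as F

    num_nodes, in_dim, num_classes = 100, 16, 7
    src = torch.randint(0, num_nodes, (500,))
    dst = torch.randint(0, num_nodes, (500,))
    # self-loops avoid zero-in-degree nodes, which GATConv rejects by default
    g = dgl.add_self_loop(dgl.graph((src, dst), num_nodes=num_nodes))

    model = GAT(g,
                num_layers=2,
                in_dim=in_dim,
                num_hidden=8,
                num_classes=num_classes,
                heads=[8, 8, 1],
                activation=F.elu,
                feat_drop=0.6,
                attn_drop=0.6,
                negative_slope=0.2,
                residual=False)

    feats = torch.randn(num_nodes, in_dim)
    logits = model(feats)  # shape: (num_nodes, num_classes)
    print(logits.shape)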