"""
Graph Attention Networks in DGL using SPMV optimization.
References
----------
Paper: https://arxiv.org/abs/1710.10903
Author's code: https://github.com/PetarV-/GAT
PyTorch implementation: https://github.com/Diego999/pyGAT
"""

import mxnet.gluon.nn as nn
from dgl.nn.mxnet.conv import GATConv


class GAT(nn.Block):
    def __init__(self,
                 g,
                 num_layers,
                 in_dim,
                 num_hidden,
                 num_classes,
                 heads,
                 activation,
                 feat_drop,
                 attn_drop,
                 alpha,
                 residual):
        super(GAT, self).__init__()
        self.g = g
        self.num_layers = num_layers
        self.gat_layers = []
        self.activation = activation
        # input projection (no residual)
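        # GATConv arguments are passed positionally as
        # (in_feats, out_feats, num_heads, feat_drop, attn_drop,
        #  negative_slope, residual); `alpha` therefore acts as the
        # LeakyReLU negative slope of the attention mechanism.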
        self.gat_layers.append(GATConv(
            in_dim, num_hidden, heads[0],
            feat_drop, attn_drop, alpha, False))
        # hidden layers
        for l in range(1, num_layers):
            # because of multi-head attention, the input dimension is
            # num_hidden * (number of heads in the previous layer)
            self.gat_layers.append(GATConv(
                num_hidden * heads[l-1], num_hidden, heads[l],
                feat_drop, attn_drop, alpha, residual))
        # output projection
        self.gat_layers.append(GATConv(
            num_hidden * heads[-2], num_classes, heads[-1],
            feat_drop, attn_drop, alpha, residual))
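        # Gluon only auto-registers blocks assigned as attributes; layers kept
        # in a plain Python list must be registered as children explicitly so
        # their parameters are initialized and updated.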
        for i, layer in enumerate(self.gat_layers):
            self.register_child(layer, "gat_layer_{}".format(i))

    def forward(self, inputs):
        h = inputs
        for l in range(self.num_layers):
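            # each GATConv returns shape (N, num_heads, num_hidden);
            # flatten() concatenates the heads into (N, num_heads * num_hidden)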
            h = self.gat_layers[l](self.g, h).flatten()
            h = self.activation(h)
        # output projection
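        # the final layer averages its heads instead of concatenating them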
        logits = self.gat_layers[-1](self.g, h).mean(1)
        return logits
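

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal example that builds
# the model on Cora. It assumes a DGL release that still ships the MXNet
# backend (DGLBACKEND=mxnet) and the CoraGraphDataset loader; the settings
# below (num_hidden=8, heads=[8, 1], dropout 0.6, alpha 0.2) are illustrative
# assumptions, not values fixed by this module.
if __name__ == '__main__':
    import dgl
    import mxnet as mx
    from dgl.data import CoraGraphDataset

    dataset = CoraGraphDataset()
    g = dgl.add_self_loop(dataset[0])  # avoid zero in-degree nodes
    features = g.ndata['feat']

    model = GAT(g,
                num_layers=1,                  # one multi-head layer + output layer
                in_dim=features.shape[1],
                num_hidden=8,
                num_classes=dataset.num_classes,
                heads=[8, 1],                  # 8 attention heads, 1 output head
                activation=mx.nd.relu,
                feat_drop=0.6,
                attn_drop=0.6,
                alpha=0.2,                     # LeakyReLU negative slope in attention
                residual=False)
    model.initialize()
    logits = model(features)                   # shape: (num_nodes, num_classes)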