"""
Graph Attention Networks in DGL using SPMV optimization.
References
----------
Paper: https://arxiv.org/abs/1710.10903
Author's code: https://github.com/PetarV-/GAT
Pytorch implementation: https://github.com/Diego999/pyGAT
"""

import tensorflow as tf
from dgl.nn import GATConv


class GAT(tf.keras.Model):
    # NOTE: `heads` must have num_layers + 1 entries: one head count per
    # GAT layer, including the final output projection.
    def __init__(
        self,
        g,
        num_layers,
        in_dim,
        num_hidden,
        num_classes,
        heads,
        activation,
        feat_drop,
        attn_drop,
        negative_slope,
        residual,
    ):
        super(GAT, self).__init__()
        self.g = g
        self.num_layers = num_layers
        self.gat_layers = []
        self.activation = activation
        # input projection (no residual)
        self.gat_layers.append(
            GATConv(
                in_dim,
                num_hidden,
                heads[0],
                feat_drop,
                attn_drop,
                negative_slope,
                False,
                self.activation,
            )
        )
        # hidden layers
        for l in range(1, num_layers):
            # the previous layer concatenates its heads, so the input
            # dimension here is num_hidden * heads[l - 1]
            self.gat_layers.append(
                GATConv(
                    num_hidden * heads[l - 1],
                    num_hidden,
                    heads[l],
                    feat_drop,
                    attn_drop,
                    negative_slope,
                    residual,
                    self.activation,
                )
            )
        # output projection
        self.gat_layers.append(
            GATConv(
                num_hidden * heads[-2],
                num_classes,
                heads[-1],
                feat_drop,
                attn_drop,
                negative_slope,
                residual,
                None,
            )
        )

    def call(self, inputs):
        h = inputs
        for l in range(self.num_layers):
            h = self.gat_layers[l](self.g, h)
            # concatenate heads: (num_nodes, num_heads, num_hidden)
            # -> (num_nodes, num_heads * num_hidden)
            h = tf.reshape(h, (h.shape[0], -1))
        # output projection: average over heads instead of concatenating
        logits = tf.reduce_mean(self.gat_layers[-1](self.g, h), axis=1)
        return logits
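

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original example). Assumes DGL is
# running on its TensorFlow backend (DGLBACKEND=tensorflow); the dataset and
# hyperparameters below are illustrative only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import dgl
    from dgl.data import CoraGraphDataset

    dataset = CoraGraphDataset()
    # GATConv expects no zero-in-degree nodes, so add self-loops
    g = dgl.add_self_loop(dataset[0])
    features = g.ndata["feat"]

    num_layers = 1  # one hidden GAT layer plus the output projection
    model = GAT(
        g,
        num_layers=num_layers,
        in_dim=features.shape[1],
        num_hidden=8,
        num_classes=dataset.num_classes,
        heads=[8] * num_layers + [1],  # len(heads) == num_layers + 1
        activation=tf.nn.elu,
        feat_drop=0.6,
        attn_drop=0.6,
        negative_slope=0.2,
        residual=False,
    )
    logits = model(features)  # shape: (num_nodes, num_classes)
    print(logits.shape)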