"""
Graph Attention Networks v2 (GATv2) in DGL using SPMV optimization.
References
----------
Paper: https://arxiv.org/pdf/2105.14491.pdf
Author's code: https://github.com/tech-srl/how_attentive_are_gats
"""

import torch
import torch.nn as nn

from dgl.nn import GATv2Conv


class GATv2(nn.Module):
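    """GATv2 model for node classification built by stacking GATv2Conv layers.

    The network has ``num_layers`` attention layers followed by an output
    projection layer; ``heads`` gives the number of attention heads per layer
    and must contain ``num_layers + 1`` entries.
    """
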
    def __init__(
        self,
        num_layers,
        in_dim,
        num_hidden,
        num_classes,
        heads,
        activation,
        feat_drop,
        attn_drop,
        negative_slope,
        residual,
    ):
        super(GATv2, self).__init__()
        self.num_layers = num_layers
        self.gatv2_layers = nn.ModuleList()
        self.activation = activation
        # input projection (no residual)
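        # GATv2Conv positional arguments: in_feats, out_feats, num_heads,
        # feat_drop, attn_drop, negative_slope, residual, activation;
        # bias and share_weights are passed as keywords.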
        self.gatv2_layers.append(
            GATv2Conv(
                in_dim,
                num_hidden,
                heads[0],
                feat_drop,
                attn_drop,
                negative_slope,
                False,
                self.activation,
                bias=False,
                share_weights=True,
            )
        )
        # hidden layers
        for l in range(1, num_layers):
            # due to the multi-head outputs, the input dimension is num_hidden * num_heads
            self.gatv2_layers.append(
                GATv2Conv(
                    num_hidden * heads[l - 1],
                    num_hidden,
                    heads[l],
                    feat_drop,
                    attn_drop,
                    negative_slope,
                    residual,
                    self.activation,
                    bias=False,
                    share_weights=True,
                )
            )
        # output projection
        self.gatv2_layers.append(
            GATv2Conv(
                num_hidden * heads[-2],
                num_classes,
                heads[-1],
                feat_drop,
                attn_drop,
                negative_slope,
                residual,
                None,
                bias=False,
                share_weights=True,
            )
        )

    def forward(self, g, inputs):
        h = inputs
        for l in range(self.num_layers):
            h = self.gatv2_layers[l](g, h).flatten(1)  # concatenate head outputs
        # output projection
        logits = self.gatv2_layers[-1](g, h).mean(1)  # average over heads
        return logits
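

# Minimal usage sketch (not part of the original example): build a toy random
# graph and run one forward pass. All sizes and hyperparameters below are
# illustrative assumptions.
if __name__ == "__main__":
    import dgl
    import torch.nn.functional as F

    # Self-loops guarantee every node has an in-degree of at least 1,
    # which GATv2Conv requires by default.
    g = dgl.add_self_loop(dgl.rand_graph(100, 500))
    feats = torch.randn(100, 16)

    model = GATv2(
        num_layers=1,
        in_dim=16,
        num_hidden=8,
        num_classes=7,
        heads=[8, 1],  # num_layers + 1 entries: hidden heads, then output heads
        activation=F.elu,
        feat_drop=0.6,
        attn_drop=0.6,
        negative_slope=0.2,
        residual=False,
    )
    logits = model(g, feats)
    print(logits.shape)  # torch.Size([100, 7])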