"""
Graph Attention Networks in DGL using SPMV optimization.
References
----------
Paper: https://arxiv.org/pdf/2105.14491.pdf
Author's code: https://github.com/tech-srl/how_attentive_are_gats
"""

import torch
import torch.nn as nn
from dgl.nn import GATv2Conv


class GATv2(nn.Module):
    """Stack of ``GATv2Conv`` layers followed by a ``GATv2Conv`` output layer.

    The model holds ``num_layers`` layers producing ``num_hidden`` features
    per head, plus one output layer producing ``num_classes`` features per
    head.  All layers are built with ``bias=False`` and
    ``share_weights=True``.

    Parameters
    ----------
    num_layers : int
        Number of layers preceding the output layer.
    in_dim : int
        Size of the input node features.
    num_hidden : int
        Per-head output size of every layer except the output layer.
    num_classes : int
        Per-head output size of the output layer.
    heads : sequence of int
        Number of attention heads per layer.  ``heads[0]`` ..
        ``heads[num_layers - 1]`` are used for the stacked layers, and
        ``heads[-2]`` / ``heads[-1]`` give the input and output head counts
        of the output layer, so the sequence is expected to contain
        ``num_layers + 1`` entries.
    activation : callable or None
        Activation handed to every layer except the output layer.
    feat_drop : float
        Feature dropout rate forwarded to each ``GATv2Conv``.
    attn_drop : float
        Attention dropout rate forwarded to each ``GATv2Conv``.
    negative_slope : float
        LeakyReLU negative slope forwarded to each ``GATv2Conv``.
    residual : bool
        Whether layers after the first one use a residual connection.
    """

    def __init__(self,
                 num_layers,
                 in_dim,
                 num_hidden,
                 num_classes,
                 heads,
                 activation,
                 feat_drop,
                 attn_drop,
                 negative_slope,
                 residual):
        super().__init__()
        self.num_layers = num_layers
        self.gatv2_layers = nn.ModuleList()
        self.activation = activation

        # input projection (no residual)
        self.gatv2_layers.append(GATv2Conv(
            in_dim, num_hidden, heads[0],
            feat_drop=feat_drop,
            attn_drop=attn_drop,
            negative_slope=negative_slope,
            residual=False,
            activation=self.activation,
            bias=False,
            share_weights=True))

        # hidden layers
        for layer_idx in range(1, num_layers):
            # due to multi-head, the in_dim = num_hidden * num_heads
            self.gatv2_layers.append(GATv2Conv(
                num_hidden * heads[layer_idx - 1], num_hidden,
                heads[layer_idx],
                feat_drop=feat_drop,
                attn_drop=attn_drop,
                negative_slope=negative_slope,
                residual=residual,
                activation=self.activation,
                bias=False,
                share_weights=True))

        # output projection (no activation)
        self.gatv2_layers.append(GATv2Conv(
            num_hidden * heads[-2], num_classes, heads[-1],
            feat_drop=feat_drop,
            attn_drop=attn_drop,
            negative_slope=negative_slope,
            residual=residual,
            activation=None,
            bias=False,
            share_weights=True))

    def forward(self, g, inputs):
        """Run the network on graph ``g`` and return the logits.

        Parameters
        ----------
        g : DGLGraph
            Graph passed to every ``GATv2Conv`` layer.
        inputs : torch.Tensor
            Input node features handed to the first layer.

        Returns
        -------
        torch.Tensor
            Output of the last layer reduced with ``mean`` over dimension 1
            (the head dimension per the DGL ``GATv2Conv`` documentation).
        """
        h = inputs
        for layer_idx in range(self.num_layers):
            # GATv2Conv.forward takes (graph, feat); the graph must be passed
            # explicitly.  flatten(1) merges every dimension after the first
            # so the next layer receives a 2-D feature matrix.
            h = self.gatv2_layers[layer_idx](g, h).flatten(1)
        # output projection
        logits = self.gatv2_layers[-1](g, h).mean(1)
        return logits