""" Graph Attention Networks in DGL using SPMV optimization. Multiple heads are also batched together for faster training. Compared with the original paper, this code does not implement multiple output attention heads. References ---------- Paper: https://arxiv.org/abs/1710.10903 Author's code: https://github.com/PetarV-/GAT Pytorch implementation: https://github.com/Diego999/pyGAT """ import argparse import numpy as np import time import torch import torch.nn as nn import torch.nn.functional as F from dgl import DGLGraph from dgl.data import register_data_args, load_data import dgl.function as fn class GraphAttention(nn.Module): def __init__(self, g, in_dim, out_dim, num_heads, feat_drop, attn_drop, alpha, residual=False): super(GraphAttention, self).__init__() self.g = g self.num_heads = num_heads self.fc = nn.Linear(in_dim, num_heads * out_dim, bias=False) if feat_drop: self.feat_drop = nn.Dropout(feat_drop) else: self.feat_drop = None if attn_drop: self.attn_drop = nn.Dropout(attn_drop) else: self.attn_drop = None self.attn_l = nn.Parameter(torch.Tensor(size=(num_heads, out_dim, 1))) self.attn_r = nn.Parameter(torch.Tensor(size=(num_heads, out_dim, 1))) nn.init.xavier_normal_(self.fc.weight.data, gain=1.414) nn.init.xavier_normal_(self.attn_l.data, gain=1.414) nn.init.xavier_normal_(self.attn_r.data, gain=1.414) self.leaky_relu = nn.LeakyReLU(alpha) self.residual = residual if residual: if in_dim != out_dim: self.residual_fc = nn.Linear(in_dim, num_heads * out_dim, bias=False) nn.init.xavier_normal_(self.fc.weight.data, gain=1.414) else: self.residual_fc = None def forward(self, inputs): # prepare h = inputs if self.feat_drop: h = self.feat_drop(h) ft = self.fc(h).reshape((h.shape[0], self.num_heads, -1)) head_ft = ft.transpose(0, 1) a1 = torch.bmm(head_ft, self.attn_l).transpose(0, 1) a2 = torch.bmm(head_ft, self.attn_r).transpose(0, 1) if self.feat_drop: ft = self.feat_drop(ft) self.g.ndata.update({'ft' : ft, 'a1' : a1, 'a2' : a2}) # 1. compute edge attention self.g.apply_edges(self.edge_attention) # 2. compute two results, one is the node features scaled by the dropped, # unnormalized attention values. Another is the normalizer of the attention values. self.g.update_all([fn.src_mul_edge('ft', 'a_drop', 'ft'), fn.copy_edge('a', 'a')], [fn.sum('ft', 'ft'), fn.sum('a', 'z')]) # 3. apply normalizer ret = self.g.ndata['ft'] / self.g.ndata['z'] # 4. residual if self.residual: if self.residual_fc: ret = self.residual_fc(h) + ret else: ret = h + ret return ret def edge_attention(self, edges): # an edge UDF to compute unnormalized attention values from src and dst a = self.leaky_relu(edges.src['a1'] + edges.dst['a2']) a = torch.exp(a).clamp(-10, 10) # use clamp to avoid overflow if self.attn_drop: a_drop = self.attn_drop(a) return {'a' : a, 'a_drop' : a_drop} class GAT(nn.Module): def __init__(self, g, num_layers, in_dim, num_hidden, num_classes, num_heads, activation, feat_drop, attn_drop, alpha, residual): super(GAT, self).__init__() self.g = g self.num_layers = num_layers self.gat_layers = nn.ModuleList() self.activation = activation # input projection (no residual) self.gat_layers.append(GraphAttention( g, in_dim, num_hidden, num_heads, feat_drop, attn_drop, alpha, False)) # hidden layers for l in range(num_layers - 1): # due to multi-head, the in_dim = num_hidden * num_heads self.gat_layers.append(GraphAttention( g, num_hidden * num_heads, num_hidden, num_heads, feat_drop, attn_drop, alpha, residual)) # output projection self.gat_layers.append(GraphAttention( g, num_hidden * num_heads, num_classes, 8, feat_drop, attn_drop, alpha, residual)) def forward(self, inputs): h = inputs for l in range(self.num_layers): h = self.gat_layers[l](h).flatten(1) h = self.activation(h) # output projection logits = self.gat_layers[-1](h).sum(1) return logits def accuracy(logits, labels): _, indices = torch.max(logits, dim=1) correct = torch.sum(indices == labels) return correct.item() * 1.0 / len(labels) def evaluate(model, features, labels, mask): model.eval() with torch.no_grad(): logits = model(features) logits = logits[mask] labels = labels[mask] return accuracy(logits, labels) def main(args): # load and preprocess dataset data = load_data(args) features = torch.FloatTensor(data.features) labels = torch.LongTensor(data.labels) train_mask = torch.ByteTensor(data.train_mask) val_mask = torch.ByteTensor(data.val_mask) test_mask = torch.ByteTensor(data.test_mask) num_feats = features.shape[1] n_classes = data.num_labels n_edges = data.graph.number_of_edges() print("""----Data statistics------' #Edges %d #Classes %d #Train samples %d #Val samples %d #Test samples %d""" % (n_edges, n_classes, train_mask.sum().item(), val_mask.sum().item(), test_mask.sum().item())) if args.gpu < 0: cuda = False else: cuda = True torch.cuda.set_device(args.gpu) features = features.cuda() labels = labels.cuda() train_mask = train_mask.cuda() val_mask = val_mask.cuda() test_mask = test_mask.cuda() # create DGL graph g = DGLGraph(data.graph) n_edges = g.number_of_edges() # add self loop g.add_edges(g.nodes(), g.nodes()) # create model model = GAT(g, args.num_layers, num_feats, args.num_hidden, n_classes, args.num_heads, F.elu, args.in_drop, args.attn_drop, args.alpha, args.residual) print(model) if cuda: model.cuda() loss_fcn = torch.nn.CrossEntropyLoss() # use optimizer optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) # initialize graph dur = [] for epoch in range(args.epochs): model.train() if epoch >= 3: t0 = time.time() # forward logits = model(features) loss = loss_fcn(logits[train_mask], labels[train_mask]) optimizer.zero_grad() loss.backward() optimizer.step() if epoch >= 3: dur.append(time.time() - t0) train_acc = accuracy(logits[train_mask], labels[train_mask]) if args.fastmode: val_acc = accuracy(logits[val_mask], labels[val_mask]) else: val_acc = evaluate(model, features, labels, val_mask) print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | TrainAcc {:.4f} |" " ValAcc {:.4f} | ETputs(KTEPS) {:.2f}". format(epoch, np.mean(dur), loss.item(), train_acc, val_acc, n_edges / np.mean(dur) / 1000)) print() acc = evaluate(model, features, labels, test_mask) print("Test Accuracy {:.4f}".format(acc)) if __name__ == '__main__': parser = argparse.ArgumentParser(description='GAT') register_data_args(parser) parser.add_argument("--gpu", type=int, default=-1, help="which GPU to use. Set -1 to use CPU.") parser.add_argument("--epochs", type=int, default=300, help="number of training epochs") parser.add_argument("--num-heads", type=int, default=8, help="number of hidden attention heads") parser.add_argument("--num-out-heads", type=int, default=1, help="number of output attention heads") parser.add_argument("--num-layers", type=int, default=1, help="number of hidden layers") parser.add_argument("--num-hidden", type=int, default=8, help="number of hidden units") parser.add_argument("--residual", action="store_true", default=False, help="use residual connection") parser.add_argument("--in-drop", type=float, default=.6, help="input feature dropout") parser.add_argument("--attn-drop", type=float, default=.6, help="attention dropout") parser.add_argument("--lr", type=float, default=0.005, help="learning rate") parser.add_argument('--weight-decay', type=float, default=5e-4, help="weight decay") parser.add_argument('--alpha', type=float, default=0.2, help="the negative slop of leaky relu") parser.add_argument('--fastmode', action="store_true", default=False, help="skip re-evaluate the validation set") args = parser.parse_args() print(args) main(args)