""" Graph Attention Networks in DGL using SPMV optimization. Multiple heads are also batched together for faster training. Compared with the original paper, this code does not implement early stopping. References ---------- Paper: https://arxiv.org/abs/1710.10903 Author's code: https://github.com/PetarV-/GAT Pytorch implementation: https://github.com/Diego999/pyGAT """ import argparse import numpy as np import time import torch import torch.nn as nn import torch.nn.functional as F from dgl import DGLGraph from dgl.data import register_data_args, load_data import dgl.function as fn class GraphAttention(nn.Module): def __init__(self, g, in_dim, out_dim, num_heads, feat_drop, attn_drop, alpha, residual=False): super(GraphAttention, self).__init__() self.g = g self.num_heads = num_heads self.fc = nn.Linear(in_dim, num_heads * out_dim, bias=False) if feat_drop: self.feat_drop = nn.Dropout(feat_drop) else: self.feat_drop = lambda x : x if attn_drop: self.attn_drop = nn.Dropout(attn_drop) else: self.attn_drop = lambda x : x self.attn_l = nn.Parameter(torch.Tensor(size=(num_heads, out_dim, 1))) self.attn_r = nn.Parameter(torch.Tensor(size=(num_heads, out_dim, 1))) nn.init.xavier_normal_(self.fc.weight.data, gain=1.414) nn.init.xavier_normal_(self.attn_l.data, gain=1.414) nn.init.xavier_normal_(self.attn_r.data, gain=1.414) self.leaky_relu = nn.LeakyReLU(alpha) self.residual = residual if residual: if in_dim != out_dim: self.res_fc = nn.Linear(in_dim, num_heads * out_dim, bias=False) nn.init.xavier_normal_(self.res_fc.weight.data, gain=1.414) else: self.res_fc = None def forward(self, inputs): # prepare h = self.feat_drop(inputs) # NxD ft = self.fc(h).reshape((h.shape[0], self.num_heads, -1)) # NxHxD' head_ft = ft.transpose(0, 1) # HxNxD' a1 = torch.bmm(head_ft, self.attn_l).transpose(0, 1) # NxHx1 a2 = torch.bmm(head_ft, self.attn_r).transpose(0, 1) # NxHx1 self.g.ndata.update({'ft' : ft, 'a1' : a1, 'a2' : a2}) # 1. compute edge attention self.g.apply_edges(self.edge_attention) # 2. compute softmax in two parts: exp(x - max(x)) and sum(exp(x - max(x))) self.edge_softmax() # 2. compute the aggregated node features scaled by the dropped, # unnormalized attention values. self.g.update_all(fn.src_mul_edge('ft', 'a_drop', 'ft'), fn.sum('ft', 'ft')) # 3. apply normalizer ret = self.g.ndata['ft'] / self.g.ndata['z'] # NxHxD' # 4. residual if self.residual: if self.res_fc is not None: resval = self.res_fc(h).reshape((h.shape[0], self.num_heads, -1)) # NxHxD' else: resval = torch.unsqueeze(h, 1) # Nx1xD' ret = resval + ret return ret def edge_attention(self, edges): # an edge UDF to compute unnormalized attention values from src and dst a = self.leaky_relu(edges.src['a1'] + edges.dst['a2']) return {'a' : a} def edge_softmax(self): # compute the max self.g.update_all(fn.copy_edge('a', 'a'), fn.max('a', 'a_max')) # minus the max and exp self.g.apply_edges(lambda edges : {'a' : torch.exp(edges.data['a'] - edges.dst['a_max'])}) # compute dropout self.g.apply_edges(lambda edges : {'a_drop' : self.attn_drop(edges.data['a'])}) # compute normalizer self.g.update_all(fn.copy_edge('a', 'a'), fn.sum('a', 'z')) class GAT(nn.Module): def __init__(self, g, num_layers, in_dim, num_hidden, num_classes, heads, activation, feat_drop, attn_drop, alpha, residual): super(GAT, self).__init__() self.g = g self.num_layers = num_layers self.gat_layers = nn.ModuleList() self.activation = activation # input projection (no residual) self.gat_layers.append(GraphAttention( g, in_dim, num_hidden, heads[0], feat_drop, attn_drop, alpha, False)) # hidden layers for l in range(1, num_layers): # due to multi-head, the in_dim = num_hidden * num_heads self.gat_layers.append(GraphAttention( g, num_hidden * heads[l-1], num_hidden, heads[l], feat_drop, attn_drop, alpha, residual)) # output projection self.gat_layers.append(GraphAttention( g, num_hidden * heads[-2], num_classes, heads[-1], feat_drop, attn_drop, alpha, residual)) def forward(self, inputs): h = inputs for l in range(self.num_layers): h = self.gat_layers[l](h).flatten(1) h = self.activation(h) # output projection logits = self.gat_layers[-1](h).mean(1) return logits def accuracy(logits, labels): _, indices = torch.max(logits, dim=1) correct = torch.sum(indices == labels) return correct.item() * 1.0 / len(labels) def evaluate(model, features, labels, mask): model.eval() with torch.no_grad(): logits = model(features) logits = logits[mask] labels = labels[mask] return accuracy(logits, labels) def main(args): # load and preprocess dataset data = load_data(args) features = torch.FloatTensor(data.features) labels = torch.LongTensor(data.labels) train_mask = torch.ByteTensor(data.train_mask) val_mask = torch.ByteTensor(data.val_mask) test_mask = torch.ByteTensor(data.test_mask) num_feats = features.shape[1] n_classes = data.num_labels n_edges = data.graph.number_of_edges() print("""----Data statistics------' #Edges %d #Classes %d #Train samples %d #Val samples %d #Test samples %d""" % (n_edges, n_classes, train_mask.sum().item(), val_mask.sum().item(), test_mask.sum().item())) if args.gpu < 0: cuda = False else: cuda = True torch.cuda.set_device(args.gpu) features = features.cuda() labels = labels.cuda() train_mask = train_mask.cuda() val_mask = val_mask.cuda() test_mask = test_mask.cuda() # create DGL graph g = DGLGraph(data.graph) n_edges = g.number_of_edges() # add self loop g.add_edges(g.nodes(), g.nodes()) # create model heads = ([args.num_heads] * args.num_layers) + [args.num_out_heads] model = GAT(g, args.num_layers, num_feats, args.num_hidden, n_classes, heads, F.elu, args.in_drop, args.attn_drop, args.alpha, args.residual) print(model) if cuda: model.cuda() loss_fcn = torch.nn.CrossEntropyLoss() # use optimizer optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) # initialize graph dur = [] for epoch in range(args.epochs): model.train() if epoch >= 3: t0 = time.time() # forward logits = model(features) loss = loss_fcn(logits[train_mask], labels[train_mask]) optimizer.zero_grad() loss.backward() optimizer.step() if epoch >= 3: dur.append(time.time() - t0) train_acc = accuracy(logits[train_mask], labels[train_mask]) if args.fastmode: val_acc = accuracy(logits[val_mask], labels[val_mask]) else: val_acc = evaluate(model, features, labels, val_mask) print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | TrainAcc {:.4f} |" " ValAcc {:.4f} | ETputs(KTEPS) {:.2f}". format(epoch, np.mean(dur), loss.item(), train_acc, val_acc, n_edges / np.mean(dur) / 1000)) print() acc = evaluate(model, features, labels, test_mask) print("Test Accuracy {:.4f}".format(acc)) if __name__ == '__main__': parser = argparse.ArgumentParser(description='GAT') register_data_args(parser) parser.add_argument("--gpu", type=int, default=-1, help="which GPU to use. Set -1 to use CPU.") parser.add_argument("--epochs", type=int, default=200, help="number of training epochs") parser.add_argument("--num-heads", type=int, default=8, help="number of hidden attention heads") parser.add_argument("--num-out-heads", type=int, default=1, help="number of output attention heads") parser.add_argument("--num-layers", type=int, default=1, help="number of hidden layers") parser.add_argument("--num-hidden", type=int, default=8, help="number of hidden units") parser.add_argument("--residual", action="store_true", default=False, help="use residual connection") parser.add_argument("--in-drop", type=float, default=.6, help="input feature dropout") parser.add_argument("--attn-drop", type=float, default=.6, help="attention dropout") parser.add_argument("--lr", type=float, default=0.005, help="learning rate") parser.add_argument('--weight-decay', type=float, default=5e-4, help="weight decay") parser.add_argument('--alpha', type=float, default=0.2, help="the negative slop of leaky relu") parser.add_argument('--fastmode', action="store_true", default=False, help="skip re-evaluate the validation set") args = parser.parse_args() print(args) main(args)