""" Graph Attention Networks in DGL using SPMV optimization. Multiple heads are also batched together for faster training. References ---------- Paper: https://arxiv.org/abs/1710.10903 Author's code: https://github.com/PetarV-/GAT Pytorch implementation: https://github.com/Diego999/pyGAT """ import argparse import time import mxnet as mx import networkx as nx import numpy as np from gat import GAT from mxnet import gluon from utils import EarlyStopping import dgl from dgl.data import (CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset, register_data_args) def elu(data): return mx.nd.LeakyReLU(data, act_type="elu") def evaluate(model, features, labels, mask): logits = model(features) logits = logits[mask].asnumpy().squeeze() val_labels = labels[mask].asnumpy().squeeze() max_index = np.argmax(logits, axis=1) accuracy = np.sum(np.where(max_index == val_labels, 1, 0)) / len(val_labels) return accuracy def main(args): # load and preprocess dataset if args.dataset == "cora": data = CoraGraphDataset() elif args.dataset == "citeseer": data = CiteseerGraphDataset() elif args.dataset == "pubmed": data = PubmedGraphDataset() else: raise ValueError("Unknown dataset: {}".format(args.dataset)) g = data[0] if args.gpu < 0: cuda = False ctx = mx.cpu(0) else: cuda = True ctx = mx.gpu(args.gpu) g = g.to(ctx) features = g.ndata["feat"] labels = mx.nd.array(g.ndata["label"], dtype="float32", ctx=ctx) mask = g.ndata["train_mask"] mask = mx.nd.array(np.nonzero(mask.asnumpy())[0], ctx=ctx) val_mask = g.ndata["val_mask"] val_mask = mx.nd.array(np.nonzero(val_mask.asnumpy())[0], ctx=ctx) test_mask = g.ndata["test_mask"] test_mask = mx.nd.array(np.nonzero(test_mask.asnumpy())[0], ctx=ctx) in_feats = features.shape[1] n_classes = data.num_labels n_edges = data.graph.number_of_edges() g = dgl.remove_self_loop(g) g = dgl.add_self_loop(g) # create model heads = ([args.num_heads] * args.num_layers) + [args.num_out_heads] model = GAT( g, args.num_layers, in_feats, args.num_hidden, n_classes, heads, elu, args.in_drop, args.attn_drop, args.alpha, args.residual, ) if args.early_stop: stopper = EarlyStopping(patience=100) model.initialize(ctx=ctx) # use optimizer trainer = gluon.Trainer( model.collect_params(), "adam", {"learning_rate": args.lr} ) dur = [] for epoch in range(args.epochs): if epoch >= 3: t0 = time.time() # forward with mx.autograd.record(): logits = model(features) loss = mx.nd.softmax_cross_entropy( logits[mask].squeeze(), labels[mask].squeeze() ) loss.backward() trainer.step(mask.shape[0]) if epoch >= 3: dur.append(time.time() - t0) print( "Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f} | ETputs(KTEPS) {:.2f}".format( epoch, loss.asnumpy()[0], np.mean(dur), n_edges / np.mean(dur) / 1000, ) ) val_accuracy = evaluate(model, features, labels, val_mask) print("Validation Accuracy {:.4f}".format(val_accuracy)) if args.early_stop: if stopper.step(val_accuracy, model): break print() if args.early_stop: model.load_parameters("model.param") test_accuracy = evaluate(model, features, labels, test_mask) print("Test Accuracy {:.4f}".format(test_accuracy)) if __name__ == "__main__": parser = argparse.ArgumentParser(description="GAT") register_data_args(parser) parser.add_argument( "--gpu", type=int, default=-1, help="which GPU to use. Set -1 to use CPU.", ) parser.add_argument( "--epochs", type=int, default=200, help="number of training epochs" ) parser.add_argument( "--num-heads", type=int, default=8, help="number of hidden attention heads", ) parser.add_argument( "--num-out-heads", type=int, default=1, help="number of output attention heads", ) parser.add_argument( "--num-layers", type=int, default=1, help="number of hidden layers" ) parser.add_argument( "--num-hidden", type=int, default=8, help="number of hidden units" ) parser.add_argument( "--residual", action="store_true", default=False, help="use residual connection", ) parser.add_argument( "--in-drop", type=float, default=0.6, help="input feature dropout" ) parser.add_argument( "--attn-drop", type=float, default=0.6, help="attention dropout" ) parser.add_argument("--lr", type=float, default=0.005, help="learning rate") parser.add_argument( "--weight-decay", type=float, default=5e-4, help="weight decay" ) parser.add_argument( "--alpha", type=float, default=0.2, help="the negative slop of leaky relu", ) parser.add_argument( "--early-stop", action="store_true", default=False, help="indicates whether to use early stop or not", ) args = parser.parse_args() print(args) main(args)