""" Inductive Representation Learning on Large Graphs Paper: http://papers.nips.cc/paper/6703-inductive-representation-learning-on-large-graphs.pdf Code: https://github.com/williamleif/graphsage-simple Simple reference implementation of GraphSAGE. """ import argparse import time import numpy as np import networkx as nx import mxnet as mx from mxnet import nd, gluon from mxnet.gluon import nn from dgl import DGLGraph from dgl.data import register_data_args, load_data from dgl.nn.mxnet.conv import SAGEConv class GraphSAGE(nn.Block): def __init__(self, g, in_feats, n_hidden, n_classes, n_layers, activation, dropout, aggregator_type): super(GraphSAGE, self).__init__() self.g = g with self.name_scope(): self.layers = nn.Sequential() # input layer self.layers.add(SAGEConv(in_feats, n_hidden, aggregator_type, feat_drop=dropout, activation=activation)) # hidden layers for i in range(n_layers - 1): self.layers.add(SAGEConv(n_hidden, n_hidden, aggregator_type, feat_drop=dropout, activation=activation)) # output layer self.layers.add(SAGEConv(n_hidden, n_classes, aggregator_type, feat_drop=dropout, activation=None)) # activation None def forward(self, features): h = features for layer in self.layers: h = layer(self.g, h) return h def evaluate(model, features, labels, mask): pred = model(features).argmax(axis=1) accuracy = ((pred == labels) * mask).sum() / mask.sum().asscalar() return accuracy.asscalar() def main(args): # load and preprocess dataset data = load_data(args) features = nd.array(data.features) labels = nd.array(data.labels) train_mask = nd.array(data.train_mask) val_mask = nd.array(data.val_mask) test_mask = nd.array(data.test_mask) in_feats = features.shape[1] n_classes = data.num_labels n_edges = data.graph.number_of_edges() print("""----Data statistics------' #Edges %d #Classes %d #Train samples %d #Val samples %d #Test samples %d""" % (n_edges, n_classes, train_mask.sum().asscalar(), val_mask.sum().asscalar(), test_mask.sum().asscalar())) if args.gpu < 0: ctx = mx.cpu(0) else: ctx = mx.gpu(args.gpu) print("use cuda:", args.gpu) features = features.as_in_context(ctx) labels = labels.as_in_context(ctx) train_mask = train_mask.as_in_context(ctx) val_mask = val_mask.as_in_context(ctx) test_mask = test_mask.as_in_context(ctx) # graph preprocess and calculate normalization factor g = data.graph g.remove_edges_from(nx.selfloop_edges(g)) g = DGLGraph(g) n_edges = g.number_of_edges() # create GraphSAGE model model = GraphSAGE(g, in_feats, args.n_hidden, n_classes, args.n_layers, nd.relu, args.dropout, args.aggregator_type ) model.initialize(ctx=ctx) n_train_samples = train_mask.sum().asscalar() loss_fcn = gluon.loss.SoftmaxCELoss() print(model.collect_params()) trainer = gluon.Trainer(model.collect_params(), 'adam', {'learning_rate': args.lr, 'wd': args.weight_decay}) # initialize graph dur = [] for epoch in range(args.n_epochs): if epoch >= 3: t0 = time.time() # forward with mx.autograd.record(): pred = model(features) loss = loss_fcn(pred, labels, mx.nd.expand_dims(train_mask, 1)) loss = loss.sum() / n_train_samples loss.backward() trainer.step(batch_size=1) if epoch >= 3: loss.asscalar() dur.append(time.time() - t0) acc = evaluate(model, features, labels, val_mask) print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | " "ETputs(KTEPS) {:.2f}". 
format( epoch, np.mean(dur), loss.asscalar(), acc, n_edges / np.mean(dur) / 1000)) # test set accuracy acc = evaluate(model, features, labels, test_mask) print("Test accuracy {:.2%}".format(acc)) if __name__ == '__main__': parser = argparse.ArgumentParser(description='GraphSAGE') register_data_args(parser) parser.add_argument("--dropout", type=float, default=0.5, help="dropout probability") parser.add_argument("--gpu", type=int, default=-1, help="gpu") parser.add_argument("--lr", type=float, default=1e-2, help="learning rate") parser.add_argument("--n-epochs", type=int, default=200, help="number of training epochs") parser.add_argument("--n-hidden", type=int, default=16, help="number of hidden gcn units") parser.add_argument("--n-layers", type=int, default=1, help="number of hidden gcn layers") parser.add_argument("--weight-decay", type=float, default=5e-4, help="Weight for L2 loss") parser.add_argument("--aggregator-type", type=str, default="gcn", help="Aggregator type: mean/gcn/pool/lstm") args = parser.parse_args() print(args) main(args)
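
# Example invocation (a sketch: the script filename is assumed here, and the --dataset
# flag is expected to come from DGL's register_data_args, e.g. cora/citeseer/pubmed):
#   python graphsage.py --dataset cora --gpu 0 --aggregator-type mean --n-epochs 200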