import argparse
import time

import numpy as np
import networkx as nx
import tensorflow as tf
from tensorflow.keras import layers

import dgl
from dgl.data import register_data_args
from dgl.data import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset

from dgi import DGI, Classifier


def evaluate(model, features, labels, mask):
    # compute accuracy on the nodes selected by the boolean mask
    logits = model(features, training=False)
    logits = logits[mask]
    labels = labels[mask]
    indices = tf.math.argmax(logits, axis=1)
    acc = tf.reduce_mean(tf.cast(indices == labels, dtype=tf.float32))
    return acc.numpy().item()


def main(args):
    # load and preprocess dataset
    if args.dataset == 'cora':
        data = CoraGraphDataset()
    elif args.dataset == 'citeseer':
        data = CiteseerGraphDataset()
    elif args.dataset == 'pubmed':
        data = PubmedGraphDataset()
    else:
        raise ValueError('Unknown dataset: {}'.format(args.dataset))

    g = data[0]
    if args.gpu < 0:
        device = "/cpu:0"
    else:
        device = "/gpu:{}".format(args.gpu)
        g = g.to(device)

    with tf.device(device):
        features = g.ndata['feat']
        labels = g.ndata['label']
        train_mask = g.ndata['train_mask']
        val_mask = g.ndata['val_mask']
        test_mask = g.ndata['test_mask']
        in_feats = features.shape[1]
        n_classes = data.num_labels
        n_edges = data.graph.number_of_edges()

        # add self loop
        if args.self_loop:
            g = dgl.remove_self_loop(g)
            g = dgl.add_self_loop(g)
            n_edges = g.number_of_edges()

        # create DGI model
        dgi = DGI(g,
                  in_feats,
                  args.n_hidden,
                  args.n_layers,
                  tf.keras.layers.PReLU(
                      alpha_initializer=tf.constant_initializer(0.25)),
                  args.dropout)

        dgi_optimizer = tf.keras.optimizers.Adam(learning_rate=args.dgi_lr)

        # train deep graph infomax
        cnt_wait = 0
        best = 1e9
        best_t = 0
        dur = []
        for epoch in range(args.n_dgi_epochs):
            if epoch >= 3:
                t0 = time.time()

            with tf.GradientTape() as tape:
                loss = dgi(features)
                # Manual weight decay.
                # TensorFlow's Adam(W) optimizer implements weight decay differently
                # from PyTorch's, which leads to worse results here. Adding an L2
                # penalty on the weights to the loss reproduces the PyTorch behaviour.
                for weight in dgi.trainable_weights:
                    loss = loss + \
                        args.weight_decay * tf.nn.l2_loss(weight)

            grads = tape.gradient(loss, dgi.trainable_weights)
            dgi_optimizer.apply_gradients(zip(grads, dgi.trainable_weights))

            # early stopping on the DGI loss
            if loss < best:
                best = loss
                best_t = epoch
                cnt_wait = 0
                dgi.save_weights('best_dgi.pkl')
            else:
                cnt_wait += 1

            if cnt_wait == args.patience:
                print('Early stopping!')
                break

            if epoch >= 3:
                dur.append(time.time() - t0)

            print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | "
                  "ETputs(KTEPS) {:.2f}".format(epoch, np.mean(dur), loss.numpy().item(),
                                                n_edges / np.mean(dur) / 1000))

        # create classifier model
        classifier = Classifier(args.n_hidden, n_classes)

        classifier_optimizer = tf.keras.optimizers.Adam(
            learning_rate=args.classifier_lr)

        # train classifier on the frozen DGI embeddings
        print('Loading {}th epoch'.format(best_t))
        dgi.load_weights('best_dgi.pkl')
        embeds = dgi.encoder(features, corrupt=False)
        embeds = tf.stop_gradient(embeds)
        dur = []

        loss_fcn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        for epoch in range(args.n_classifier_epochs):
            if epoch >= 3:
                t0 = time.time()

            with tf.GradientTape() as tape:
                preds = classifier(embeds)
                loss = loss_fcn(labels[train_mask], preds[train_mask])
                # Manual weight decay could be added here as in the DGI loop above,
                # but the original DGI code applies no weight decay to the classifier:
                # https://github.com/PetarV-/DGI/blob/master/execute.py#L121
                # for weight in classifier.trainable_weights:
                #     loss = loss + \
                #         args.weight_decay * tf.nn.l2_loss(weight)

            grads = tape.gradient(loss, classifier.trainable_weights)
            classifier_optimizer.apply_gradients(
                zip(grads, classifier.trainable_weights))

            if epoch >= 3:
                dur.append(time.time() - t0)

            acc = evaluate(classifier, embeds, labels, val_mask)

            print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
                  "ETputs(KTEPS) {:.2f}".format(epoch, np.mean(dur), loss.numpy().item(),
                                                acc, n_edges / np.mean(dur) / 1000))

        print()
        acc = evaluate(classifier, embeds, labels, test_mask)
        print("Test Accuracy {:.4f}".format(acc))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='DGI')
    register_data_args(parser)
    parser.add_argument("--dropout", type=float, default=0.,
                        help="dropout probability")
    parser.add_argument("--gpu", type=int, default=-1,
                        help="gpu device id (use -1 for cpu)")
    parser.add_argument("--dgi-lr", type=float, default=1e-3,
                        help="dgi learning rate")
    parser.add_argument("--classifier-lr", type=float, default=1e-2,
                        help="classifier learning rate")
    parser.add_argument("--n-dgi-epochs", type=int, default=300,
                        help="number of DGI training epochs")
    parser.add_argument("--n-classifier-epochs", type=int, default=300,
                        help="number of classifier training epochs")
    parser.add_argument("--n-hidden", type=int, default=512,
                        help="number of hidden gcn units")
    parser.add_argument("--n-layers", type=int, default=1,
                        help="number of hidden gcn layers")
    parser.add_argument("--weight-decay", type=float, default=0.,
                        help="weight for L2 loss")
    parser.add_argument("--patience", type=int, default=20,
                        help="early stop patience condition")
    parser.add_argument("--self-loop", action='store_true',
                        help="graph self-loop (default=False)")
    parser.set_defaults(self_loop=False)
    args = parser.parse_args()
    print(args)

    main(args)
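
# Example invocation (a sketch; the filename "train.py" is an assumption, and the
# --dataset flag is registered by register_data_args):
#
#   python train.py --dataset cora --gpu -1 --self-loop
#
# The script first trains the DGI encoder unsupervised, then trains the classifier
# on the frozen embeddings and prints the test accuracy at the end.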