""" Supervised Community Detection with Hierarchical Graph Neural Networks https://arxiv.org/abs/1705.08415 Author's implementation: https://github.com/joanbruna/GNN_community """ from __future__ import division import time import argparse from itertools import permutations import numpy as np import torch as th import torch.nn.functional as F import torch.optim as optim from torch.utils.data import DataLoader from dgl.data import SBMMixtureDataset import gnn parser = argparse.ArgumentParser() parser.add_argument('--batch-size', type=int, help='Batch size', default=1) parser.add_argument('--gpu', type=int, help='GPU index', default=-1) parser.add_argument('--lr', type=float, help='Learning rate', default=0.001) parser.add_argument('--n-communities', type=int, help='Number of communities', default=2) parser.add_argument('--n-epochs', type=int, help='Number of epochs', default=100) parser.add_argument('--n-features', type=int, help='Number of features', default=16) parser.add_argument('--n-graphs', type=int, help='Number of graphs', default=10) parser.add_argument('--n-layers', type=int, help='Number of layers', default=30) parser.add_argument('--n-nodes', type=int, help='Number of nodes', default=10000) parser.add_argument('--optim', type=str, help='Optimizer', default='Adam') parser.add_argument('--radius', type=int, help='Radius', default=3) parser.add_argument('--verbose', action='store_true') args = parser.parse_args() dev = th.device('cpu') if args.gpu < 0 else th.device('cuda:%d' % args.gpu) K = args.n_communities training_dataset = SBMMixtureDataset(args.n_graphs, args.n_nodes, K) training_loader = DataLoader(training_dataset, args.batch_size, collate_fn=training_dataset.collate_fn, drop_last=True) ones = th.ones(args.n_nodes // K) y_list = [th.cat([x * ones for x in p]).long().to(dev) for p in permutations(range(K))] feats = [1] + [args.n_features] * args.n_layers + [K] model = gnn.GNN(feats, args.radius, K).to(dev) optimizer = getattr(optim, args.optim)(model.parameters(), lr=args.lr) def compute_overlap(z_list): ybar_list = [th.max(z, 1)[1] for z in z_list] overlap_list = [] for y_bar in ybar_list: accuracy = max(th.sum(y_bar == y).item() for y in y_list) / args.n_nodes overlap = (accuracy - 1 / K) / (1 - 1 / K) overlap_list.append(overlap) return sum(overlap_list) / len(overlap_list) def from_np(f, *args): def wrap(*args): new = [th.from_numpy(x) if isinstance(x, np.ndarray) else x for x in args] return f(*new) return wrap @from_np def step(i, j, g, lg, deg_g, deg_lg, pm_pd): """ One step of training. """ g = g.to(dev) lg = lg.to(dev) deg_g = deg_g.to(dev).unsqueeze(1) deg_lg = deg_lg.to(dev).unsqueeze(1) pm_pd = pm_pd.to(dev) t0 = time.time() z = model(g, lg, deg_g, deg_lg, pm_pd) t_forward = time.time() - t0 z_list = th.chunk(z, args.batch_size, 0) loss = sum(min(F.cross_entropy(z, y) for y in y_list) for z in z_list) / args.batch_size overlap = compute_overlap(z_list) optimizer.zero_grad() t0 = time.time() loss.backward() t_backward = time.time() - t0 optimizer.step() return loss, overlap, t_forward, t_backward @from_np def inference(g, lg, deg_g, deg_lg, pm_pd): g = g.to(dev) lg = lg.to(dev) deg_g = deg_g.to(dev).unsqueeze(1) deg_lg = deg_lg.to(dev).unsqueeze(1) pm_pd = pm_pd.to(dev) z = model(g, lg, deg_g, deg_lg, pm_pd) return z def test(): p_list =[6, 5.5, 5, 4.5, 1.5, 1, 0.5, 0] q_list =[0, 0.5, 1, 1.5, 4.5, 5, 5.5, 6] N = 1 overlap_list = [] for p, q in zip(p_list, q_list): dataset = SBMMixtureDataset(N, args.n_nodes, K, pq=[[p, q]] * N) loader = DataLoader(dataset, N, collate_fn=dataset.collate_fn) g, lg, deg_g, deg_lg, pm_pd = next(iter(loader)) z = inference(g, lg, deg_g, deg_lg, pm_pd) overlap_list.append(compute_overlap(th.chunk(z, N, 0))) return overlap_list n_iterations = args.n_graphs // args.batch_size for i in range(args.n_epochs): total_loss, total_overlap, s_forward, s_backward = 0, 0, 0, 0 for j, [g, lg, deg_g, deg_lg, pm_pd] in enumerate(training_loader): loss, overlap, t_forward, t_backward = step(i, j, g, lg, deg_g, deg_lg, pm_pd) total_loss += loss total_overlap += overlap s_forward += t_forward s_backward += t_backward epoch = '0' * (len(str(args.n_epochs)) - len(str(i))) iteration = '0' * (len(str(n_iterations)) - len(str(j))) if args.verbose: print('[epoch %s%d iteration %s%d]loss %.3f | overlap %.3f' % (epoch, i, iteration, j, loss, overlap)) epoch = '0' * (len(str(args.n_epochs)) - len(str(i))) loss = total_loss / (j + 1) overlap = total_overlap / (j + 1) t_forward = s_forward / (j + 1) t_backward = s_backward / (j + 1) print('[epoch %s%d]loss %.3f | overlap %.3f | forward time %.3fs | backward time %.3fs' % (epoch, i, loss, overlap, t_forward, t_backward)) overlap_list = test() overlap_str = ' - '.join(['%.3f' % overlap for overlap in overlap_list]) print('[epoch %s%d]overlap: %s' % (epoch, i, overlap_str))