"""
Supervised Community Detection with Line Graph Neural Networks
https://arxiv.org/abs/1705.08415

Author's implementation: https://github.com/joanbruna/GNN_community
"""

from __future__ import division

import argparse
import time
from itertools import permutations

import torch as th
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from dgl.data import SBMMixture
import gnn

parser = argparse.ArgumentParser()
parser.add_argument('--batch-size', type=int, help='Batch size', default=1)
parser.add_argument('--gpu', type=int, help='GPU index', default=-1)
parser.add_argument('--lr', type=float, help='Learning rate', default=0.001)
parser.add_argument('--n-communities', type=int, help='Number of communities', default=2)
parser.add_argument('--n-epochs', type=int, help='Number of epochs', default=100)
parser.add_argument('--n-features', type=int, help='Number of features', default=16)
parser.add_argument('--n-graphs', type=int, help='Number of graphs', default=10)
parser.add_argument('--n-layers', type=int, help='Number of layers', default=30)
parser.add_argument('--n-nodes', type=int, help='Number of nodes', default=10000)
parser.add_argument('--optim', type=str, help='Optimizer', default='Adam')
parser.add_argument('--radius', type=int, help='Radius', default=3)
parser.add_argument('--verbose', action='store_true')
args = parser.parse_args()

dev = th.device('cpu') if args.gpu < 0 else th.device('cuda:%d' % args.gpu)
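
# Example invocation (illustrative values; all flags are defined above):
#   python train.py --gpu 0 --n-nodes 1000 --n-communities 2 --verbose
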
K = args.n_communities

# A mixture of stochastic block model (SBM) graphs. Each collated batch yields
# the graph g, its line graph lg, degree tensors for both, and the pm_pd
# matrix that couples g with its line graph (whose nodes are the edges of g).
training_dataset = SBMMixture(args.n_graphs, args.n_nodes, K)
training_loader = DataLoader(training_dataset, args.batch_size,
                             collate_fn=training_dataset.collate_fn, drop_last=True)

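# Ground-truth communities are defined only up to a relabeling of the K
# community ids, so precompute one target vector per permutation; the loss and
# the overlap metric take the best match over all K! relabelings.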
ones = th.ones(args.n_nodes // K)
y_list = [th.cat([x * ones for x in p]).long().to(dev) for p in permutations(range(K))]

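# Layer widths: a scalar input feature, n_features hidden channels per layer,
# and K output logits (one per community).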
feats = [1] + [args.n_features] * args.n_layers + [K]
model = gnn.GNN(feats, args.radius, K).to(dev)
optimizer = getattr(optim, args.optim)(model.parameters(), lr=args.lr)

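# Overlap rescales accuracy so that random guessing scores 0 and a perfect
# assignment scores 1: overlap = (accuracy - 1/K) / (1 - 1/K). With K = 2,
# for instance, an accuracy of 0.75 gives (0.75 - 0.5) / (1 - 0.5) = 0.5.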
def compute_overlap(z_list):
    # Predicted community of each node is the argmax over the K logits.
    ybar_list = [z.argmax(1) for z in z_list]
    overlap_list = []
    for y_bar in ybar_list:
        # Score against the best-matching label permutation.
        accuracy = max(th.sum(y_bar == y).item() for y in y_list) / args.n_nodes
        overlap = (accuracy - 1 / K) / (1 - 1 / K)
        overlap_list.append(overlap)
    return sum(overlap_list) / len(overlap_list)

def step(i, j, g, lg, deg_g, deg_lg, pm_pd):
    """ One step of training. """
    t0 = time.time()
    z = model(g, lg, deg_g, deg_lg, pm_pd)  # per-node community logits
    t_forward = time.time() - t0

    z_list = th.chunk(z, args.batch_size, 0)
    # Permutation-invariant cross-entropy: for each graph in the batch, take
    # the minimum loss over all relabelings of the ground-truth communities.
    loss = sum(min(F.cross_entropy(zi, y) for y in y_list) for zi in z_list) / args.batch_size
    overlap = compute_overlap(z_list)

    optimizer.zero_grad()
    t0 = time.time()
    loss.backward()
    t_backward = time.time() - t0
    optimizer.step()

    # Return the loss as a Python float so accumulating it across iterations
    # does not retain autograd state.
    return loss.item(), overlap, t_forward, t_backward

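# Evaluate on freshly sampled single SBM graphs over a sweep of (p, q)
# connection rates, from strongly assortative (p >> q) to strongly
# disassortative (q >> p); pairs with p close to q are the hardest to detect.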
def test():
    p_list = [6, 5.5, 5, 4.5, 1.5, 1, 0.5, 0]
    q_list = [0, 0.5, 1, 1.5, 4.5, 5, 5.5, 6]
    N = 1
    overlap_list = []
    with th.no_grad():  # inference only; no gradients needed
        for p, q in zip(p_list, q_list):
            dataset = SBMMixture(N, args.n_nodes, K, pq=[[p, q]] * N)
            loader = DataLoader(dataset, N, collate_fn=dataset.collate_fn)
            g, lg, deg_g, deg_lg, pm_pd = next(iter(loader))
            deg_g = deg_g.to(dev)
            deg_lg = deg_lg.to(dev)
            pm_pd = pm_pd.to(dev)
            z = model(g, lg, deg_g, deg_lg, pm_pd)
            overlap_list.append(compute_overlap(th.chunk(z, N, 0)))
    return overlap_list

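# Training loop: one optimizer step per mini-batch, averaged metrics per epoch,
# followed by the (p, q) evaluation sweep in test().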
n_iterations = args.n_graphs // args.batch_size
for i in range(args.n_epochs):
    total_loss, total_overlap, s_forward, s_backward = 0, 0, 0, 0
    for j, (g, lg, deg_g, deg_lg, pm_pd) in enumerate(training_loader):
        # Move the dense tensors to the target device.
        deg_g = deg_g.to(dev)
        deg_lg = deg_lg.to(dev)
        pm_pd = pm_pd.to(dev)
        loss, overlap, t_forward, t_backward = step(i, j, g, lg, deg_g, deg_lg, pm_pd)

        total_loss += loss
        total_overlap += overlap
        s_forward += t_forward
        s_backward += t_backward

        # Zero-pad the indices so log lines align across the run.
        epoch = str(i).zfill(len(str(args.n_epochs)))
        iteration = str(j).zfill(len(str(n_iterations)))
        if args.verbose:
            print('[epoch %s iteration %s] loss %.3f | overlap %.3f'
                  % (epoch, iteration, loss, overlap))

    epoch = str(i).zfill(len(str(args.n_epochs)))
    loss = total_loss / (j + 1)
    overlap = total_overlap / (j + 1)
    t_forward = s_forward / (j + 1)
    t_backward = s_backward / (j + 1)
    print('[epoch %s] loss %.3f | overlap %.3f | forward time %.3fs | backward time %.3fs'
          % (epoch, loss, overlap, t_forward, t_backward))

    overlap_list = test()
    overlap_str = ' - '.join(['%.3f' % overlap for overlap in overlap_list])
    print('[epoch %s] overlap: %s' % (epoch, overlap_str))