"""
Supervised Community Detection with Line Graph Neural Networks
https://arxiv.org/abs/1705.08415

Author's implementation: https://github.com/joanbruna/GNN_community
"""

from __future__ import division
import time

import argparse
from itertools import permutations

import numpy as np
import torch as th
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from dgl.data import SBMMixture
import gnn

parser = argparse.ArgumentParser()
parser.add_argument('--batch-size', type=int, help='Batch size', default=1)
parser.add_argument('--gpu', type=int, help='GPU index', default=-1)
parser.add_argument('--lr', type=float, help='Learning rate', default=0.001)
parser.add_argument('--n-communities', type=int, help='Number of communities', default=2)
parser.add_argument('--n-epochs', type=int, help='Number of epochs', default=100)
parser.add_argument('--n-features', type=int, help='Number of features', default=16)
parser.add_argument('--n-graphs', type=int, help='Number of graphs', default=10)
parser.add_argument('--n-layers', type=int, help='Number of layers', default=30)
parser.add_argument('--n-nodes', type=int, help='Number of nodes', default=10000)
parser.add_argument('--optim', type=str, help='Optimizer', default='Adam')
parser.add_argument('--radius', type=int, help='Radius', default=3)
parser.add_argument('--verbose', action='store_true')
args = parser.parse_args()

# Run on CPU by default; a non-negative --gpu index selects that CUDA device.
dev = th.device('cpu') if args.gpu < 0 else th.device('cuda:%d' % args.gpu)
K = args.n_communities

# Training data: a mixture of SBM graphs; the dataset's collate function
# batches each graph with its line graph, both degree vectors, and the
# pm_pd matrix coupling the two (Pm/Pd in the paper's notation).
training_dataset = SBMMixture(args.n_graphs, args.n_nodes, K)
training_loader = DataLoader(training_dataset, args.batch_size,
                             collate_fn=training_dataset.collate_fn, drop_last=True)

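# Ground-truth communities are balanced blocks of n_nodes // K consecutive
# nodes. Since community labels are only identifiable up to permutation,
# precompute all K! relabelings; the loss and the overlap metric score a
# prediction against the best-matching relabeling.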
ones = th.ones(args.n_nodes // K)
y_list = [th.cat([x * ones for x in p]).long().to(dev) for p in permutations(range(K))]

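# Layer widths: a 1-dimensional input feature (presumably the node degree,
# deg_g) expanded to n_features channels for n_layers layers, then projected
# to K community logits.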
feats = [1] + [args.n_features] * args.n_layers + [K]
model = gnn.GNN(feats, args.radius, K).to(dev)
optimizer = getattr(optim, args.optim)(model.parameters(), lr=args.lr)

def compute_overlap(z_list):
    """Mean overlap of the predicted communities, where overlap rescales
    accuracy (under the best label permutation) so that chance level is 0
    and a perfect labeling is 1: overlap = (acc - 1/K) / (1 - 1/K)."""
    ybar_list = [th.max(z, 1)[1] for z in z_list]
    overlap_list = []
    for y_bar in ybar_list:
        accuracy = max(th.sum(y_bar == y).item() for y in y_list) / args.n_nodes
        overlap = (accuracy - 1 / K) / (1 - 1 / K)
        overlap_list.append(overlap)
    return sum(overlap_list) / len(overlap_list)

def from_np(f):
    """Decorator that converts any numpy array argument to a torch tensor."""
    def wrap(*args):
        new = [th.from_numpy(x) if isinstance(x, np.ndarray) else x for x in args]
        return f(*new)
    return wrap

@from_np
def step(i, j, g, lg, deg_g, deg_lg, pm_pd):
    """One step of training on a batch: the graph, its line graph, their
    degree vectors, and the pm_pd coupling matrix."""
    # Move the dense inputs to the target device.
    deg_g = deg_g.to(dev)
    deg_lg = deg_lg.to(dev)
    pm_pd = pm_pd.to(dev)
    t0 = time.time()
    z = model(g, lg, deg_g, deg_lg, pm_pd)
    t_forward = time.time() - t0
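
    # The model output z stacks every graph in the batch; split it back into
    # per-graph chunks and take, for each graph, the minimum cross-entropy
    # over all label permutations (a permutation-invariant loss).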
    z_list = th.chunk(z, args.batch_size, 0)
    loss = sum(min(F.cross_entropy(z, y) for y in y_list) for z in z_list) / args.batch_size
    overlap = compute_overlap(z_list)

    optimizer.zero_grad()
    t0 = time.time()
    loss.backward()
    t_backward = time.time() - t0
    optimizer.step()

    # Return a Python float so the accumulators below do not hold on to the
    # autograd graph.
    return loss.item(), overlap, t_forward, t_backward

@from_np
def inference(g, lg, deg_g, deg_lg, pm_pd):
    """Forward pass only, without building the autograd graph."""
    deg_g = deg_g.to(dev)
    deg_lg = deg_lg.to(dev)
    pm_pd = pm_pd.to(dev)

    with th.no_grad():
        z = model(g, lg, deg_g, deg_lg, pm_pd)

    return z
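# The test sweep varies the SBM rate pair (p, q), presumably intra- vs.
# inter-community connectivity, from strongly assortative (6, 0) to strongly
# disassortative (0, 6), and reports the overlap at each setting.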
def test():
    p_list = [6, 5.5, 5, 4.5, 1.5, 1, 0.5, 0]
    q_list = [0, 0.5, 1, 1.5, 4.5, 5, 5.5, 6]
    N = 1
    overlap_list = []
    for p, q in zip(p_list, q_list):
        dataset = SBMMixture(N, args.n_nodes, K, pq=[[p, q]] * N)
        loader = DataLoader(dataset, N, collate_fn=dataset.collate_fn)
        g, lg, deg_g, deg_lg, pm_pd = next(iter(loader))
        z = inference(g, lg, deg_g, deg_lg, pm_pd)
        overlap_list.append(compute_overlap(th.chunk(z, N, 0)))
    return overlap_list

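# Main loop: each epoch runs once over the mixture, logs loss, overlap and
# forward/backward timings averaged over iterations, then re-evaluates on
# the (p, q) sweep above.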
n_iterations = args.n_graphs // args.batch_size
for i in range(args.n_epochs):
    total_loss, total_overlap, s_forward, s_backward = 0, 0, 0, 0
    for j, [g, lg, deg_g, deg_lg, pm_pd] in enumerate(training_loader):
        loss, overlap, t_forward, t_backward = step(i, j, g, lg, deg_g, deg_lg, pm_pd)

        total_loss += loss
        total_overlap += overlap
        s_forward += t_forward
        s_backward += t_backward

        if args.verbose:
            epoch = str(i).zfill(len(str(args.n_epochs)))
            iteration = str(j).zfill(len(str(n_iterations)))
            print('[epoch %s iteration %s] loss %.3f | overlap %.3f'
                  % (epoch, iteration, loss, overlap))

    # Epoch summary: averages over the (j + 1) iterations actually run.
    epoch = str(i).zfill(len(str(args.n_epochs)))
    loss = total_loss / (j + 1)
    overlap = total_overlap / (j + 1)
    t_forward = s_forward / (j + 1)
    t_backward = s_backward / (j + 1)
    print('[epoch %s] loss %.3f | overlap %.3f | forward time %.3fs | backward time %.3fs'
          % (epoch, loss, overlap, t_forward, t_backward))

    overlap_list = test()
    overlap_str = ' - '.join(['%.3f' % overlap for overlap in overlap_list])
    print('[epoch %s] overlap: %s' % (epoch, overlap_str))