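"""
Supervised Community Detection with Hierarchical Graph Neural Networks
https://arxiv.org/abs/1705.08415

Reference implementation by the authors: https://github.com/joanbruna/GNN_community

This script trains ``gnn.GNN`` on synthetic ``SBMMixtureDataset`` graphs and
reports the permutation-invariant overlap after every epoch.
"""
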
from __future__ import division

import argparse
import functools
import time
from itertools import permutations

import numpy as np
import torch as th
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from dgl.data import SBMMixtureDataset

import gnn

def _make_parser():
    """Return the command-line parser for this training script."""
    p = argparse.ArgumentParser()
    # (flag, type, help text, default), registered in this order so that
    # --help lists the options in the same sequence as before.
    for flag, kind, text, default in (
            ('--batch-size', int, 'Batch size', 1),
            ('--gpu', int, 'GPU index', -1),
            ('--lr', float, 'Learning rate', 0.001),
            ('--n-communities', int, 'Number of communities', 2),
            ('--n-epochs', int, 'Number of epochs', 100),
            ('--n-features', int, 'Number of features', 16),
            ('--n-graphs', int, 'Number of graphs', 10),
            ('--n-layers', int, 'Number of layers', 30),
            ('--n-nodes', int, 'Number of nodes', 10000),
            ('--optim', str, 'Optimizer', 'Adam'),
            ('--radius', int, 'Radius', 3),
    ):
        p.add_argument(flag, type=kind, help=text, default=default)
    p.add_argument('--verbose', action='store_true')
    return p


parser = _make_parser()
args = parser.parse_args()

# A negative --gpu selects the CPU; otherwise use the given CUDA device.
if args.gpu < 0:
    dev = th.device('cpu')
else:
    dev = th.device('cuda:%d' % args.gpu)

K = args.n_communities

training_dataset = SBMMixtureDataset(args.n_graphs, args.n_nodes, K)
# drop_last=True: an incomplete final batch is discarded, so every batch
# holds exactly args.batch_size graphs.
training_loader = DataLoader(
    training_dataset,
    args.batch_size,
    collate_fn=training_dataset.collate_fn,
    drop_last=True,
)

# Candidate ground-truth labelings, one per permutation of the K community
# ids. Each one labels consecutive runs of n_nodes // K nodes with a single id.
# NOTE(review): this assumes SBMMixtureDataset orders nodes by community --
# confirm against the dataset implementation.
ones = th.ones(args.n_nodes // K)
y_list = [
    th.cat([community * ones for community in perm]).long().to(dev)
    for perm in permutations(range(K))
]

# Feature sizes handed to gnn.GNN: 1, then n_layers entries of n_features, then K.
feats = [1] + [args.n_features] * args.n_layers + [K]
model = gnn.GNN(feats, args.radius, K).to(dev)
# --optim names a class in torch.optim (e.g. 'Adam', 'SGD').
optimizer_cls = getattr(optim, args.optim)
optimizer = optimizer_cls(model.parameters(), lr=args.lr)

def compute_overlap(z_list, labels=None, n_communities=None):
    """Return the mean permutation-invariant overlap of predictions with the truth.

    For each score tensor ``z`` the predicted community of every row is its
    argmax. The accuracy is the best match over all candidate labelings
    (community ids are arbitrary, so any permutation is an equally valid
    answer). It is then rescaled to ``(acc - 1/K) / (1 - 1/K)``, so 1.0 means
    every node matches and 0.0 means accuracy equal to 1/K.

    Args:
        z_list: sequence of per-graph score tensors, one row per node and one
            column per community.
        labels: candidate label tensors (one per permutation of community
            ids). Defaults to the module-level ``y_list``.
        n_communities: number of communities K. Defaults to the module-level ``K``.

    Returns:
        The overlap averaged over the graphs in ``z_list`` (a Python float).

    Raises:
        ValueError: if ``z_list`` is empty.
    """
    if labels is None:
        labels = y_list
    if n_communities is None:
        n_communities = K
    chance = 1 / n_communities

    overlap_list = []
    for z in z_list:
        y_bar = z.argmax(dim=1)
        n_correct = max(th.sum(y_bar == y).item() for y in labels)
        # Normalize by the size of the prediction actually being scored rather
        # than a global node count, so the score stays in [0, 1].
        accuracy = n_correct / y_bar.numel()
        overlap_list.append((accuracy - chance) / (1 - chance))

    if not overlap_list:
        raise ValueError('compute_overlap() needs at least one prediction')
    return sum(overlap_list) / len(overlap_list)

def from_np(f, *_ignored):
    """Decorator: convert numpy-array positional arguments of ``f`` to torch tensors.

    Every positional argument that is an ``np.ndarray`` is converted with
    ``th.from_numpy`` (the tensor shares memory with the array). All other
    positional arguments, and all keyword arguments, are forwarded unchanged.
    The wrapper keeps ``f``'s name and docstring.

    Any extra positional arguments given to ``from_np`` itself are ignored.
    They are still accepted so that existing call sites keep working.
    """
    @functools.wraps(f)
    def wrap(*call_args, **kwargs):
        new = [th.from_numpy(x) if isinstance(x, np.ndarray) else x
               for x in call_args]
        return f(*new, **kwargs)
    return wrap

def _synchronize():
    """Wait for queued CUDA work on ``dev`` to finish (no-op on the CPU).

    CUDA kernels are launched asynchronously. Without a sync, wall-clock
    timings around a GPU operation mostly measure launch overhead.
    """
    if dev.type == 'cuda':
        th.cuda.synchronize(dev)


@from_np
def step(i, j, g, lg, deg_g, deg_lg, pm_pd):
    """Run one optimization step on a batch of graphs.

    Args:
        i, j: epoch and iteration indices. They are not used here and are
            kept so the caller's signature is unchanged.
        g, lg: the batched graph and its companion graph (presumably the line
            graph -- confirm). Both are moved to ``dev``.
        deg_g, deg_lg: per-node tensors for ``g`` / ``lg`` (presumably
            degrees). They are moved to ``dev`` and given a trailing dimension
            of size 1.
        pm_pd: moved to ``dev`` and forwarded to the model unchanged.

    Returns:
        ``(loss, overlap, t_forward, t_backward)``:
        - ``loss`` is the batch-averaged permutation-invariant cross-entropy
          (a tensor).
        - ``overlap`` is the mean overlap from ``compute_overlap``.
        - The two times are in seconds.
    """
    g = g.to(dev)
    lg = lg.to(dev)
    deg_g = deg_g.to(dev).unsqueeze(1)
    deg_lg = deg_lg.to(dev).unsqueeze(1)
    pm_pd = pm_pd.to(dev)

    _synchronize()
    t0 = time.time()
    z = model(g, lg, deg_g, deg_lg, pm_pd)
    _synchronize()
    t_forward = time.time() - t0

    # Split the batched output back into one score tensor per graph. Community
    # ids are arbitrary, so each graph's loss is the minimum cross-entropy over
    # every permutation of the ground-truth labels.
    z_list = th.chunk(z, args.batch_size, 0)
    loss = sum(min(F.cross_entropy(z_g, y) for y in y_list)
               for z_g in z_list) / args.batch_size
    overlap = compute_overlap(z_list)

    optimizer.zero_grad()
    _synchronize()
    t0 = time.time()
    loss.backward()
    _synchronize()
    t_backward = time.time() - t0
    optimizer.step()

    return loss, overlap, t_forward, t_backward

@from_np
def inference(g, lg, deg_g, deg_lg, pm_pd):
    """Run the model forward on one batch without recording autograd history.

    Inputs are prepared exactly as in ``step``: everything is moved to
    ``dev``, and ``deg_g``/``deg_lg`` get a trailing dimension of size 1.

    Returns:
        The model's raw output for the batch. It does not require grad.
    """
    g = g.to(dev)
    lg = lg.to(dev)
    deg_g = deg_g.to(dev).unsqueeze(1)
    deg_lg = deg_lg.to(dev).unsqueeze(1)
    pm_pd = pm_pd.to(dev)

    # Evaluation never calls backward(), so skip building the autograd graph
    # to save memory and time. The model's train/eval mode is left untouched,
    # as before.
    # NOTE(review): if gnn.GNN contains BatchNorm/Dropout, consider
    # model.eval() here -- that would change outputs, so it is not done.
    with th.no_grad():
        z = model(g, lg, deg_g, deg_lg, pm_pd)

    return z

def test():
    """Score the current model on freshly sampled graphs over a (p, q) sweep.

    For each pair from (p, q) = (6, 0) down to (0, 6), one graph is drawn
    with ``SBMMixtureDataset(1, args.n_nodes, K, pq=[[p, q]])``. The graph is
    run through ``inference`` and scored with ``compute_overlap``.

    Returns:
        A list of overlap values, one per (p, q) pair, in sweep order.
    """
    sweep = [(6, 0), (5.5, 0.5), (5, 1), (4.5, 1.5),
             (1.5, 4.5), (1, 5), (0.5, 5.5), (0, 6)]
    n_graphs = 1  # one graph per setting, evaluated as a single batch

    def overlap_at(p, q):
        dataset = SBMMixtureDataset(n_graphs, args.n_nodes, K,
                                    pq=[[p, q]] * n_graphs)
        batches = DataLoader(dataset, n_graphs, collate_fn=dataset.collate_fn)
        g, lg, deg_g, deg_lg, pm_pd = next(iter(batches))
        scores = inference(g, lg, deg_g, deg_lg, pm_pd)
        return compute_overlap(th.chunk(scores, n_graphs, 0))

    return [overlap_at(p, q) for p, q in sweep]

n_iterations = args.n_graphs // args.batch_size
# Widths used to zero-pad epoch / iteration numbers in log lines
# (e.g. epoch 7 of 100 prints as "007").
epoch_width = len(str(args.n_epochs))
iteration_width = len(str(n_iterations))
for i in range(args.n_epochs):
    total_loss, total_overlap, s_forward, s_backward = 0, 0, 0, 0
    n_steps = 0
    epoch = '0' * (epoch_width - len(str(i)))
    for j, [g, lg, deg_g, deg_lg, pm_pd] in enumerate(training_loader):
        loss, overlap, t_forward, t_backward = step(i, j, g, lg, deg_g, deg_lg, pm_pd)
        # Accumulate a plain Python float. Summing the loss tensors directly
        # would chain every iteration's autograd history onto total_loss and
        # keep it alive until the end of the epoch.
        loss = float(loss)

        total_loss += loss
        total_overlap += overlap
        s_forward += t_forward
        s_backward += t_backward
        n_steps += 1

        if args.verbose:
            iteration = '0' * (iteration_width - len(str(j)))
            print('[epoch %s%d iteration %s%d]loss %.3f | overlap %.3f'
                  % (epoch, i, iteration, j, loss, overlap))

    if n_steps == 0:
        # training_loader uses drop_last=True, so a dataset smaller than one
        # batch yields no iterations at all and the averages are undefined.
        raise RuntimeError('training_loader yielded no batches; '
                           'is --n-graphs smaller than --batch-size?')

    loss = total_loss / n_steps
    overlap = total_overlap / n_steps
    t_forward = s_forward / n_steps
    t_backward = s_backward / n_steps
    print('[epoch %s%d]loss %.3f | overlap %.3f | forward time %.3fs | backward time %.3fs'
          % (epoch, i, loss, overlap, t_forward, t_backward))

    overlap_list = test()
    overlap_str = ' - '.join(['%.3f' % overlap for overlap in overlap_list])
    print('[epoch %s%d]overlap: %s' % (epoch, i, overlap_str))