"""
Supervised Community Detection with Hierarchical Graph Neural Networks
https://arxiv.org/abs/1705.08415

Author's implementation: https://github.com/joanbruna/GNN_community
"""

from __future__ import division

import argparse
import functools
import time
from itertools import permutations

import numpy as np
import torch as th
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from dgl.data import SBMMixtureDataset

import gnn

parser = argparse.ArgumentParser()
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
parser.add_argument("--batch-size", type=int, help="Batch size", default=1)
parser.add_argument("--gpu", type=int, help="GPU index", default=-1)
parser.add_argument("--lr", type=float, help="Learning rate", default=0.001)
parser.add_argument(
    "--n-communities", type=int, help="Number of communities", default=2
)
parser.add_argument(
    "--n-epochs", type=int, help="Number of epochs", default=100
)
parser.add_argument(
    "--n-features", type=int, help="Number of features", default=16
)
parser.add_argument("--n-graphs", type=int, help="Number of graphs", default=10)
parser.add_argument("--n-layers", type=int, help="Number of layers", default=30)
parser.add_argument(
    "--n-nodes", type=int, help="Number of nodes", default=10000
)
parser.add_argument("--optim", type=str, help="Optimizer", default="Adam")
parser.add_argument("--radius", type=int, help="Radius", default=3)
parser.add_argument("--verbose", action="store_true")
GaiYu0's avatar
GaiYu0 committed
44
45
args = parser.parse_args()

46
dev = th.device("cpu") if args.gpu < 0 else th.device("cuda:%d" % args.gpu)
47
48
K = args.n_communities

49
training_dataset = SBMMixtureDataset(args.n_graphs, args.n_nodes, K)
50
51
52
53
54
55
training_loader = DataLoader(
    training_dataset,
    args.batch_size,
    collate_fn=training_dataset.collate_fn,
    drop_last=True,
)
56
57

ones = th.ones(args.n_nodes // K)
58
59
60
y_list = [
    th.cat([x * ones for x in p]).long().to(dev) for p in permutations(range(K))
]
61
62
63
64
65

feats = [1] + [args.n_features] * args.n_layers + [K]
model = gnn.GNN(feats, args.radius, K).to(dev)
optimizer = getattr(optim, args.optim)(model.parameters(), lr=args.lr)

66

67
68
69
70
71
72
73
74
75
def compute_overlap(z_list):
    ybar_list = [th.max(z, 1)[1] for z in z_list]
    overlap_list = []
    for y_bar in ybar_list:
        accuracy = max(th.sum(y_bar == y).item() for y in y_list) / args.n_nodes
        overlap = (accuracy - 1 / K) / (1 - 1 / K)
        overlap_list.append(overlap)
    return sum(overlap_list) / len(overlap_list)

76

HQ's avatar
HQ committed
77
78
def from_np(f, *args):
    def wrap(*args):
79
80
81
        new = [
            th.from_numpy(x) if isinstance(x, np.ndarray) else x for x in args
        ]
HQ's avatar
HQ committed
82
        return f(*new)
83

HQ's avatar
HQ committed
84
85
    return wrap

86

HQ's avatar
HQ committed
87
@from_np
88
def step(i, j, g, lg, deg_g, deg_lg, pm_pd):
89
    """One step of training."""
90
91
92
93
    g = g.to(dev)
    lg = lg.to(dev)
    deg_g = deg_g.to(dev).unsqueeze(1)
    deg_lg = deg_lg.to(dev).unsqueeze(1)
HQ's avatar
HQ committed
94
    pm_pd = pm_pd.to(dev)
95
96
97
    t0 = time.time()
    z = model(g, lg, deg_g, deg_lg, pm_pd)
    t_forward = time.time() - t0
GaiYu0's avatar
GaiYu0 committed
98

99
    z_list = th.chunk(z, args.batch_size, 0)
100
101
102
103
    loss = (
        sum(min(F.cross_entropy(z, y) for y in y_list) for z in z_list)
        / args.batch_size
    )
104
105
106
107
    overlap = compute_overlap(z_list)

    optimizer.zero_grad()
    t0 = time.time()
GaiYu0's avatar
GaiYu0 committed
108
    loss.backward()
109
110
111
112
113
    t_backward = time.time() - t0
    optimizer.step()

    return loss, overlap, t_forward, t_backward

114

HQ's avatar
HQ committed
115
116
@from_np
def inference(g, lg, deg_g, deg_lg, pm_pd):
117
118
119
120
    g = g.to(dev)
    lg = lg.to(dev)
    deg_g = deg_g.to(dev).unsqueeze(1)
    deg_lg = deg_lg.to(dev).unsqueeze(1)
HQ's avatar
HQ committed
121
122
123
124
125
    pm_pd = pm_pd.to(dev)

    z = model(g, lg, deg_g, deg_lg, pm_pd)

    return z
126
127


128
def test():
129
130
    p_list = [6, 5.5, 5, 4.5, 1.5, 1, 0.5, 0]
    q_list = [0, 0.5, 1, 1.5, 4.5, 5, 5.5, 6]
131
132
133
    N = 1
    overlap_list = []
    for p, q in zip(p_list, q_list):
134
        dataset = SBMMixtureDataset(N, args.n_nodes, K, pq=[[p, q]] * N)
135
136
        loader = DataLoader(dataset, N, collate_fn=dataset.collate_fn)
        g, lg, deg_g, deg_lg, pm_pd = next(iter(loader))
HQ's avatar
HQ committed
137
        z = inference(g, lg, deg_g, deg_lg, pm_pd)
138
139
140
        overlap_list.append(compute_overlap(th.chunk(z, N, 0)))
    return overlap_list

141

142
143
144
145
n_iterations = args.n_graphs // args.batch_size
for i in range(args.n_epochs):
    total_loss, total_overlap, s_forward, s_backward = 0, 0, 0, 0
    for j, [g, lg, deg_g, deg_lg, pm_pd] in enumerate(training_loader):
146
147
148
        loss, overlap, t_forward, t_backward = step(
            i, j, g, lg, deg_g, deg_lg, pm_pd
        )
149
150
151
152
153
154

        total_loss += loss
        total_overlap += overlap
        s_forward += t_forward
        s_backward += t_backward

155
156
        epoch = "0" * (len(str(args.n_epochs)) - len(str(i)))
        iteration = "0" * (len(str(n_iterations)) - len(str(j)))
157
        if args.verbose:
158
159
160
161
            print(
                "[epoch %s%d iteration %s%d]loss %.3f | overlap %.3f"
                % (epoch, i, iteration, j, loss, overlap)
            )
GaiYu0's avatar
GaiYu0 committed
162

163
    epoch = "0" * (len(str(args.n_epochs)) - len(str(i)))
164
165
166
167
    loss = total_loss / (j + 1)
    overlap = total_overlap / (j + 1)
    t_forward = s_forward / (j + 1)
    t_backward = s_backward / (j + 1)
168
169
170
171
    print(
        "[epoch %s%d]loss %.3f | overlap %.3f | forward time %.3fs | backward time %.3fs"
        % (epoch, i, loss, overlap, t_forward, t_backward)
    )
GaiYu0's avatar
GaiYu0 committed
172

173
    overlap_list = test()
174
175
    overlap_str = " - ".join(["%.3f" % overlap for overlap in overlap_list])
    print("[epoch %s%d]overlap: %s" % (epoch, i, overlap_str))