"vscode:/vscode.git/clone" did not exist on "524535b5f20b2c0987549580ded8706f905a4d37"
train.py 1.93 KB
Newer Older
GaiYu0's avatar
GaiYu0 committed
1
from __future__ import division
GaiYu0's avatar
GaiYu0 committed
2
3
4

import argparse
from itertools import permutations
GaiYu0's avatar
GaiYu0 committed
5

GaiYu0's avatar
GaiYu0 committed
6
7
8
9
import networkx as nx
import torch as th
import torch.nn.functional as F
import torch.optim as optim
GaiYu0's avatar
GaiYu0 committed
10
11
12
13
from torch.utils.data import DataLoader

import dgl
from dgl.data import SBMMixture
GaiYu0's avatar
GaiYu0 committed
14
import gnn
GaiYu0's avatar
GaiYu0 committed
15
import utils
GaiYu0's avatar
GaiYu0 committed
16
17
18
19

# Command-line configuration. Only --gpu has a usable fallback (CPU); every
# other option is required, because a missing value would otherwise surface as
# an obscure TypeError later in the script -- or, for --model-path, only at the
# final th.save call after all training iterations have run.
parser = argparse.ArgumentParser(
    description='Train gnn.GNN on SBMMixture graphs with a loss that is '
                'invariant to permutations of the community labels.')
parser.add_argument('--batch-size', type=int, required=True,
                    help='number of graphs per training batch')
parser.add_argument('--gpu', type=int, default=-1,
                    help='CUDA device index; a negative value selects the CPU '
                         '(default: -1)')
parser.add_argument('--n-communities', type=int, required=True,
                    help='number of communities, i.e. label classes per node')
parser.add_argument('--n-features', type=int, required=True,
                    help='width of each hidden layer')
parser.add_argument('--n-graphs', type=int, required=True,
                    help='number of graphs passed to SBMMixture')
parser.add_argument('--n-iterations', type=int, required=True,
                    help='number of optimizer steps')
parser.add_argument('--n-layers', type=int, required=True,
                    help='number of hidden layers of width --n-features')
parser.add_argument('--n-nodes', type=int, required=True,
                    help='number of nodes per graph')
parser.add_argument('--model-path', type=str, required=True,
                    help="file to which the trained model's state_dict is saved")
parser.add_argument('--radius', type=int, required=True,
                    help='radius argument forwarded to gnn.GNN')
args = parser.parse_args()

# Fail fast on inconsistent sizes instead of deep inside training.
if args.n_communities <= 0 or args.n_nodes % args.n_communities != 0:
    # Per-graph labels are built from n_communities blocks of
    # n_nodes // n_communities entries each; if the division is not exact the
    # label vector ends up shorter than --n-nodes.
    parser.error('--n-nodes must be a multiple of a positive --n-communities')
if args.batch_size <= 0 or args.batch_size > args.n_graphs:
    # The DataLoader uses drop_last=True, so fewer graphs than one batch
    # would produce no batches at all.
    parser.error('--batch-size must be between 1 and --n-graphs')

# Negative --gpu means CPU; otherwise use the requested CUDA device.
dev = th.device('cpu') if args.gpu < 0 else th.device('cuda:%d' % args.gpu)

GaiYu0's avatar
GaiYu0 committed
32
33
34
# Training data: SBMMixture graphs served in shuffled batches of exactly
# batch_size graphs (drop_last=True discards a trailing partial batch).
# utils.cycle presumably repeats the loader indefinitely so next(loader) never
# runs dry -- confirm in utils.
dataset = SBMMixture(args.n_graphs, args.n_nodes, args.n_communities)
batches = DataLoader(dataset, args.batch_size, shuffle=True,
                     collate_fn=dataset.collate_fn, drop_last=True)
loader = utils.cycle(batches)

# Candidate targets. Community ids are interchangeable, so one target is
# precomputed per permutation of range(n_communities). Each target is
# n_communities equal blocks of a single label per graph, tiled batch_size
# times, as int64 on `dev`.
# NOTE(review): assumes SBMMixture orders each graph's nodes community by
# community in equal contiguous blocks -- confirm against dgl.data.SBMMixture.
community_size = args.n_nodes // args.n_communities
ones = th.ones(community_size)
y_list = []
for perm in permutations(range(args.n_communities)):
    one_graph = th.cat([label * ones for label in perm])
    y_list.append(th.cat([one_graph] * args.batch_size).long().to(dev))

# Layer widths handed to gnn.GNN: input width 1, n_layers hidden layers of
# width n_features, and n_communities output scores per node.
hidden = [args.n_features] * args.n_layers
feats = [1] + hidden + [args.n_communities]
model = gnn.GNN(feats, args.radius, args.n_communities)
model = model.to(dev)
GaiYu0's avatar
GaiYu0 committed
42
43
44
opt = optim.Adamax(model.parameters(), lr=0.04)

# Digits in the iteration count; the step counter in the log is zero-padded
# to this width so log lines stay aligned.
counter_width = len(str(args.n_iterations))

for step in range(args.n_iterations):
    g, lg, deg_g, deg_lg, eid2nid = next(loader)
    # Only the tensor inputs are moved to `dev`; g and lg are passed to the
    # model exactly as the loader returned them.
    deg_g, deg_lg, eid2nid = [t.to(dev) for t in (deg_g, deg_lg, eid2nid)]
    scores = model(g, lg, deg_g, deg_lg, eid2nid)
    # Community ids are arbitrary: train against whichever relabelling in
    # y_list the current prediction matches best (smallest cross-entropy).
    loss = min(F.cross_entropy(scores, target) for target in y_list)

    opt.zero_grad()
    loss.backward()
    opt.step()

    print('[iteration %0*d]loss %f' % (counter_width, step, loss))

# Persist only the learned parameters, not the module object.
th.save(model.state_dict(), args.model_path)