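"""Train a graph neural network for community detection on a mixture of
stochastic block model (SBM) graphs, using the input graph together with its
line graph."""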
from __future__ import division

import argparse
from itertools import permutations

import networkx as nx
import torch as th
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

import dgl
from dgl.data import SBMMixture
import gnn
import utils

parser = argparse.ArgumentParser()
parser.add_argument('--batch-size', type=int,
                    help='Batch size', default=1)
parser.add_argument('--gpu', type=int,
                    help='GPU device id (-1 for CPU)', default=-1)
parser.add_argument('--n-communities', type=int,
                    help='Number of communities', default=2)
parser.add_argument('--n-features', type=int,
                    help='Number of features per layer', default=2)
parser.add_argument('--n-graphs', type=int,
                    help='Number of graphs', default=6000)
parser.add_argument('--n-iterations', type=int,
                    help='Number of iterations', default=10000)
parser.add_argument('--n-layers', type=int,
                    help='Number of layers', default=30)
parser.add_argument('--n-nodes', type=int,
                    help='Number of nodes', default=1000)
parser.add_argument('--model-path', type=str,
                    help='Path to save the model checkpoint', default='model')
parser.add_argument('--radius', type=int,
                    help='Radius', default=3)
args = parser.parse_args()

dev = th.device('cpu') if args.gpu < 0 else th.device('cuda:%d' % args.gpu)

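# A mixture of stochastic-block-model graphs; the loader cycles over the
# dataset indefinitely so each training iteration draws a fresh batch.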
dataset = SBMMixture(args.n_graphs, args.n_nodes, args.n_communities)
loader = utils.cycle(DataLoader(dataset, args.batch_size, shuffle=True,
                                collate_fn=dataset.collate_fn, drop_last=True))

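# Build one ground-truth label vector per permutation of the community ids:
# each community occupies a contiguous, equal-sized block of nodes.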
ones = th.ones(args.n_nodes // args.n_communities)
y_list = [th.cat([th.cat([x * ones for x in p])] * args.batch_size).long().to(dev)
          for p in permutations(range(args.n_communities))]

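# Per-layer feature sizes: one input feature per node, `n_features` hidden
# features per layer, and one output logit per community.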
feats = [1] + [args.n_features] * args.n_layers + [args.n_communities]
model = gnn.GNN(feats, args.radius, args.n_communities).to(dev)
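
# Train with Adamax for a fixed number of iterations, one batch per iteration.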
opt = optim.Adamax(model.parameters(), lr=0.04)

for i in range(args.n_iterations):
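    # Each batch consists of the graph, its line graph, the degrees of both,
    # and the mapping from edge ids (line-graph nodes) back to node ids.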
    g, lg, deg_g, deg_lg, eid2nid = next(loader)
    deg_g = deg_g.to(dev)
    deg_lg = deg_lg.to(dev)
    eid2nid = eid2nid.to(dev)
    y_bar = model(g, lg, deg_g, deg_lg, eid2nid)
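    # Community ids are only defined up to a permutation, so take the minimum
    # cross-entropy over all permuted label assignments.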
    loss = min(F.cross_entropy(y_bar, y) for y in y_list)
    opt.zero_grad()
    loss.backward()
    opt.step()

    # Zero-pad the iteration index so that log lines align.
    width = len(str(args.n_iterations))
    print('[iteration %0*d] loss %f' % (width, i, loss.item()))

# Save the trained parameters.
th.save(model.state_dict(), args.model_path)