train.py 1.53 KB
Newer Older
GaiYu0's avatar
GaiYu0 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
"""
ipython3 train.py -- --gpu -1 --n-classes 2 --n-iterations 1000 --n-layers 30 --n-nodes 1000 --n-features 2 --radius 3
"""

import argparse
from itertools import permutations
import networkx as nx
import torch as th
import torch.nn.functional as F
import torch.optim as optim
import gnn
import sbm

parser = argparse.ArgumentParser()
parser.add_argument('--batch-size', type=int)
parser.add_argument('--gpu', type=int)
parser.add_argument('--n-classes', type=int)
parser.add_argument('--n-features', type=int)
parser.add_argument('--n-graphs', type=int)
parser.add_argument('--n-iterations', type=int)
parser.add_argument('--n-layers', type=int)
parser.add_argument('--n-nodes', type=int)
parser.add_argument('--radius', type=int)
args = parser.parse_args()

dev = th.device('cpu') if args.gpu < 0 else th.device('cuda:%d' % args.gpu)

ssbm = sbm.SSBM(args.n_nodes, args.n_classes, 1, 1)
gg = []
for i in range(args.n_graphs):
    ssbm.generate()
    gg.append(ssbm.graph)

assert args.n_nodes % args.n_classes == 0
ones = th.ones(int(args.n_nodes / args.n_classes))
yy = [th.cat([x * ones for x in p]).long().to(dev)
      for p in permutations(range(args.n_classes))]

feats = [1] + [args.n_features] * args.n_layers + [args.n_classes]
model = gnn.GNN(g, feats, args.radius, args.n_classes).to(dev)
opt = optim.Adamax(model.parameters(), lr=0.04)

for i in range(args.n_iterations):
    y_bar = model()
    loss = min(F.cross_entropy(y_bar, y) for y in yy)
    opt.zero_grad()
    loss.backward()
    opt.step()

    print('[iteration %d]loss %f' % (i, loss))