train_main.py 2.71 KB
Newer Older
KounianhuaDu's avatar
KounianhuaDu committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# The training codes of the dummy model


import os
import argparse
import dgl

import torch as th
import torch.nn as nn

from dgl import save_graphs
from models import dummy_gnn_model
from gengraph import gen_syn1, gen_syn2, gen_syn3, gen_syn4, gen_syn5
import numpy as np

def main(args):
    # load dataset
    if args.dataset == 'syn1':
        g, labels, name = gen_syn1()
    elif args.dataset == 'syn2':
        g, labels, name = gen_syn2()
    elif args.dataset == 'syn3':
        g, labels, name = gen_syn3()
    elif args.dataset == 'syn4':
        g, labels, name = gen_syn4()
    elif args.dataset == 'syn5':
        g, labels, name = gen_syn5()
    else:
        raise NotImplementedError
    
    #Transform to dgl graph. 
    graph = dgl.from_networkx(g) 
    labels = th.tensor(labels, dtype=th.long)
    graph.ndata['label'] = labels
    graph.ndata['feat'] = th.randn(graph.number_of_nodes(), args.feat_dim)
    hid_dim = th.tensor(args.hidden_dim, dtype=th.long)
    label_dict = {'hid_dim':hid_dim}

    # save graph for later use
    save_graphs(filename='./'+args.dataset+'.bin', g_list=[graph], labels=label_dict)

    num_classes = max(graph.ndata['label']).item() + 1
    n_feats = graph.ndata['feat']

    #create model
    dummy_model = dummy_gnn_model(args.feat_dim, args.hidden_dim, num_classes)
    loss_fn = nn.CrossEntropyLoss()
    optim = th.optim.Adam(dummy_model.parameters(), lr=args.lr, weight_decay=args.wd)

    # train and output
    for epoch in range(args.epochs):

        dummy_model.train()

        logits = dummy_model(graph, n_feats)
        loss = loss_fn(logits, labels)
        acc = th.sum(logits.argmax(dim=1) == labels).item() / len(labels)
        
        optim.zero_grad()
        loss.backward()
        optim.step()

        print('In Epoch: {:03d}; Acc: {:.4f}; Loss: {:.6f}'.format(epoch, acc, loss.item()))

    # save model
    model_stat_dict = dummy_model.state_dict()
    model_path = os.path.join('./', 'dummy_model_{}.pth'.format(args.dataset))
    th.save(model_stat_dict, model_path)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Dummy model training')
    parser.add_argument('--dataset', type=str, default='syn1', help='The dataset used for training the model.')
    parser.add_argument('--feat_dim', type=int, default=10, help='The feature dimension.')
    parser.add_argument('--hidden_dim', type=int, default=40, help='The hidden dimension.')
    parser.add_argument('--epochs', type=int, default=500, help='The number of epochs.')
    parser.add_argument('--lr', type=float, default=0.001, help='The learning rate.')
    parser.add_argument('--wd', type=float, default=0.0, help='Weight decay.')

    args = parser.parse_args()
    print(args)

    main(args)