train_main.py 1.73 KB
Newer Older
KounianhuaDu's avatar
KounianhuaDu committed
1
2
3
4
5
6
7
import os
import argparse

import torch as th
import torch.nn as nn

from dgl import save_graphs
8
from models import Model
KounianhuaDu's avatar
KounianhuaDu committed
9

10
from dgl.data import BAShapeDataset, BACommunityDataset, TreeCycleDataset, TreeGridDataset
KounianhuaDu's avatar
KounianhuaDu committed
11

12
13
14
15
16
17
18
19
20
21
22
23
def main(args):
    """Train ``Model`` on the selected DGL dataset and save its weights.

    The first graph of the chosen dataset is loaded, a ``Model`` is built from
    the node-feature width and the dataset's class count, and the model is
    optimised with Adam on a cross-entropy loss for 500 epochs. Accuracy and
    loss are printed after every epoch, and the final ``state_dict`` is written
    to ``./model_<dataset>.pth``.

    Args:
        args: Namespace-like object with a ``dataset`` attribute set to one of
            ``'BAShape'``, ``'BACommunity'``, ``'TreeCycle'`` or ``'TreeGrid'``.

    Raises:
        ValueError: If ``args.dataset`` is not one of the supported names.
    """
    if args.dataset == 'BAShape':
        dataset = BAShapeDataset(seed=0)
    elif args.dataset == 'BACommunity':
        dataset = BACommunityDataset(seed=0)
    elif args.dataset == 'TreeCycle':
        dataset = TreeCycleDataset(seed=0)
    elif args.dataset == 'TreeGrid':
        dataset = TreeGridDataset(seed=0)
    else:
        # Fail with a clear message instead of hitting an unbound ``dataset``
        # below when main() is called without going through argparse.
        raise ValueError(
            f'Unknown dataset {args.dataset!r}; expected one of '
            "'BAShape', 'BACommunity', 'TreeCycle', 'TreeGrid'"
        )

    graph = dataset[0]
    labels = graph.ndata['label']
    n_feats = graph.ndata['feat']
    num_classes = dataset.num_classes

    model = Model(n_feats.shape[-1], num_classes)
    loss_fn = nn.CrossEntropyLoss()
    optim = th.optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(500):
        model.train()
        # For demo purpose, we train the model on all datapoints
        # In practice, you should train only on the training datapoints
        logits = model(graph, n_feats)
        loss = loss_fn(logits, labels)
        # Fraction of nodes whose highest-scoring class matches the label.
        acc = th.sum(logits.argmax(dim=1) == labels).item() / len(labels)

        optim.zero_grad()
        loss.backward()
        optim.step()

        print(f'In Epoch: {epoch}; Acc: {acc}; Loss: {loss.item()}')

    # Persist only the learned parameters, keyed by dataset name.
    model_stat_dict = model.state_dict()
    model_path = os.path.join('./', f'model_{args.dataset}.pth')
    th.save(model_stat_dict, model_path)

if __name__ == '__main__':
    # Script entry point: read the dataset choice from the command line,
    # echo the parsed arguments, then run training.
    parser = argparse.ArgumentParser(description='Dummy model training')
    parser.add_argument(
        '--dataset',
        type=str,
        default='BAShape',
        choices=['BAShape', 'BACommunity', 'TreeCycle', 'TreeGrid'],
    )
    args = parser.parse_args()
    print(args)

    main(args)