import sys
from parser import Parser

import mxnet as mx
import numpy as np
from dataloader import collate, GraphDataLoader
from dgl.data.gindt import GINDataset
from gin import GIN
from mxnet import gluon, nd
from mxnet.gluon import nn
from tqdm import tqdm


def train(args, net, trainloader, trainer, criterion, epoch):
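    """Run one training epoch and return the average per-batch loss."""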
    running_loss = 0
    total_iters = len(trainloader)
    # offset the bar position so the stacked progress bars do not overlap
    bar = tqdm(range(total_iters), unit="batch", position=2, file=sys.stdout)

    for pos, (graphs, labels) in zip(bar, trainloader):
        # labels and node features are moved to the device here; the
        # batched graph itself is moved just before the forward pass
        labels = labels.as_in_context(args.device)
        feat = graphs.ndata["attr"].as_in_context(args.device)

        with mx.autograd.record():
            graphs = graphs.to(args.device)
            outputs = net(graphs, feat)
            loss = criterion(outputs, labels)
            loss = loss.sum() / len(labels)

        running_loss += loss.asscalar()

        # backprop
        loss.backward()
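        # the loss was already averaged over the batch above, so step with batch_size=1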
        trainer.step(batch_size=1)

        # report
        bar.set_description("epoch-{}".format(epoch))
    bar.close()
    # each batch loss was already averaged over its own size, so dividing by
    # the iteration count is safe even if the final batch is smaller
    running_loss = running_loss / total_iters

    return running_loss


def eval_net(args, net, dataloader, criterion):
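    """Evaluate the model on `dataloader`; return mean per-sample loss and accuracy."""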
    total = 0
    total_loss = 0
    total_correct = 0

    for data in dataloader:
        graphs, labels = data
        labels = labels.as_in_context(args.device)
        feat = graphs.ndata["attr"].as_in_context(args.device)

        total += len(labels)
        graphs = graphs.to(args.device)
        outputs = net(graphs, feat)
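        # take the argmax class and cast to int64 so it compares against the integer labels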
        predicted = nd.argmax(outputs, axis=1)
        predicted = predicted.astype("int64")

        total_correct += (predicted == labels).sum().asscalar()
        loss = criterion(outputs, labels)
        # SoftmaxCELoss returns a per-sample loss vector by default, so sum before accumulating
        total_loss += loss.sum().asscalar()

    loss, acc = 1.0 * total_loss / total, 1.0 * total_correct / total

    return loss, acc


def main(args):
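    """Train and evaluate GIN on the selected fold of the chosen GINDataset."""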
    # fix the mxnet and numpy seeds; args.seed is used separately for the data split below
    mx.random.seed(0)
    np.random.seed(seed=0)

    if args.device >= 0:
        args.device = mx.gpu(args.device)
    else:
        args.device = mx.cpu()

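    # GINDataset's second positional argument toggles self-loops:
    # add them only when eps is not being learned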
    dataset = GINDataset(args.dataset, not args.learn_eps)

    trainloader, validloader = GraphDataLoader(
        dataset,
        batch_size=args.batch_size,
        collate_fn=collate,
        seed=args.seed,
        shuffle=True,
        split_name="fold10",
        fold_idx=args.fold_idx,
    ).train_valid_loader()
    # or split_name='rand', split_ratio=0.7

    model = GIN(
        args.num_layers,
        args.num_mlp_layers,
        dataset.dim_nfeats,
        args.hidden_dim,
        dataset.gclasses,
        args.final_dropout,
        args.learn_eps,
        args.graph_pooling_type,
        args.neighbor_pooling_type,
    )
    model.initialize(ctx=args.device)

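    # softmax cross-entropy; gluon computes it per sample (reduced manually in train/eval)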
    criterion = gluon.loss.SoftmaxCELoss()

    print(model.collect_params())
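    # multiply the learning rate by 0.5 every 50 trainer updates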
    lr_scheduler = mx.lr_scheduler.FactorScheduler(50, 0.5)
    trainer = gluon.Trainer(
        model.collect_params(), "adam", {"lr_scheduler": lr_scheduler}
    )

    # it is not cost-effective to handle the cursor and initialize the bars at 0
    # https://stackoverflow.com/a/23121189
    tbar = tqdm(
        range(args.epochs), unit="epoch", position=3, ncols=0, file=sys.stdout
    )
    vbar = tqdm(
        range(args.epochs), unit="epoch", position=4, ncols=0, file=sys.stdout
    )
    lrbar = tqdm(
        range(args.epochs), unit="epoch", position=5, ncols=0, file=sys.stdout
    )

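    # the three bars are zipped together so each advances once per epoch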
    for epoch, _, _ in zip(tbar, vbar, lrbar):
        train(args, model, trainloader, trainer, criterion, epoch)

        train_loss, train_acc = eval_net(args, model, trainloader, criterion)
        tbar.set_description(
            "train set - average loss: {:.4f}, accuracy: {:.0f}%".format(
                train_loss, 100.0 * train_acc
            )
        )

        valid_loss, valid_acc = eval_net(args, model, validloader, criterion)
        vbar.set_description(
            "valid set - average loss: {:.4f}, accuracy: {:.0f}%".format(
                valid_loss, 100.0 * valid_acc
            )
        )

        if args.filename != "":
            with open(args.filename, "a") as f:
                f.write(
                    "%s %s %s %s"
                    % (
                        args.dataset,
                        args.learn_eps,
                        args.neighbor_pooling_type,
                        args.graph_pooling_type,
                    )
                )
                f.write("\n")
                f.write(
                    "%f %f %f %f"
                    % (train_loss, train_acc, valid_loss, valid_acc)
                )
                f.write("\n")

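        # show the per-layer eps values (trained only when learn_eps is set)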
        lrbar.set_description(
            "Learning eps with learn_eps={}: {}".format(
                args.learn_eps,
                [
                    layer.eps.data(args.device).asscalar()
                    for layer in model.ginlayers
                ],
            )
        )

    tbar.close()
    vbar.close()
    lrbar.close()


if __name__ == "__main__":
    args = Parser(description="GIN").args
    print("show all arguments configuration...")
    print(args)

    main(args)