"""Train a Graph Isomorphism Network (GIN) on graph classification datasets (MXNet backend)."""
1
import sys

import mxnet as mx
import numpy as np
from mxnet import gluon, nd
from mxnet.gluon import nn
from tqdm import tqdm

from dgl.data.gindt import GINDataset

from dataloader import GraphDataLoader, collate
from gin import GIN
from parser import Parser


def train(args, net, trainloader, trainer, criterion, epoch):
    """Run one training epoch and return the mean per-batch loss.

    Parameters
    ----------
    args : argparse.Namespace
        Must carry ``device`` (an mxnet context set up by ``main``).
    net : gluon.Block
        The GIN model; called as ``net(graphs, feat)``.
    trainloader : iterable
        Yields ``(batched_graph, labels)`` pairs.
    trainer : gluon.Trainer
        Optimizer wrapper; stepped once per batch.
    criterion : gluon.loss.Loss
        Per-sample loss; summed and averaged over the batch here.
    epoch : int
        Current epoch index, shown in the progress-bar description.

    Returns
    -------
    float
        Running loss averaged over all iterations of this epoch.
    """
    running_loss = 0
    total_iters = len(trainloader)
    # position=2 keeps this bar on its own terminal row, below the
    # epoch-level bars created in main() (see the stackoverflow link there).
    bar = tqdm(range(total_iters), unit="batch", position=2, file=sys.stdout)

    for _, (graphs, labels) in zip(bar, trainloader):
        # batch graphs will be shipped to device in forward part of model
        labels = labels.as_in_context(args.device)
        feat = graphs.ndata["attr"].as_in_context(args.device)

        with mx.autograd.record():
            graphs = graphs.to(args.device)
            outputs = net(graphs, feat)
            loss = criterion(outputs, labels)
            # average over the batch so gradients are batch-size independent
            loss = loss.sum() / len(labels)

        running_loss += loss.asscalar()

        # backprop; batch_size=1 because the loss was already averaged above
        loss.backward()
        trainer.step(batch_size=1)

        # report
        bar.set_description("epoch-{}".format(epoch))
    bar.close()
    # the final (possibly smaller) batch is already averaged per-sample,
    # so dividing by the iteration count gives the mean per-batch loss
    running_loss = running_loss / total_iters

    return running_loss


def eval_net(args, net, dataloader, criterion):
    """Evaluate `net` over `dataloader` and return (mean loss, accuracy).

    Parameters
    ----------
    args : argparse.Namespace
        Must carry ``device`` (an mxnet context set up by ``main``).
    net : gluon.Block
        The GIN model; called as ``net(graphs, feat)``.
    dataloader : iterable
        Yields ``(batched_graph, labels)`` pairs.
    criterion : gluon.loss.Loss
        Per-sample loss used for the reported mean loss.

    Returns
    -------
    tuple of (float, float)
        Per-sample average loss and classification accuracy.
        NOTE(review): an empty dataloader would raise ZeroDivisionError;
        assumed non-empty by construction of the fold split.
    """
    total = 0
    total_loss = 0
    total_correct = 0

    for data in dataloader:
        graphs, labels = data
        labels = labels.as_in_context(args.device)
        feat = graphs.ndata["attr"].as_in_context(args.device)

        total += len(labels)
        graphs = graphs.to(args.device)
        outputs = net(graphs, feat)
        # argmax returns float indices; cast so == compares against int labels
        predicted = nd.argmax(outputs, axis=1)
        predicted = predicted.astype("int64")

        total_correct += (predicted == labels).sum().asscalar()
        loss = criterion(outputs, labels)
        # crossentropy(reduce=True) for default
        total_loss += loss.sum().asscalar()

    loss, acc = 1.0 * total_loss / total, 1.0 * total_correct / total

    return loss, acc


def main(args):
    """Full training driver: data, model, optimizer, epoch loop, logging.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed CLI options (see ``Parser``). ``args.device`` is mutated in
        place from an int to an mxnet context.
    """
    # NOTE(review): RNG seeds are hard-coded to 0; args.seed is only used
    # for the data-loader split below — confirm whether that is intended.
    mx.random.seed(0)
    np.random.seed(seed=0)

    # device >= 0 selects that GPU index; negative means CPU
    if args.device >= 0:
        args.device = mx.gpu(args.device)
    else:
        args.device = mx.cpu()

    # self-loops are added unless epsilon is being learned
    dataset = GINDataset(args.dataset, not args.learn_eps)

    trainloader, validloader = GraphDataLoader(
        dataset,
        batch_size=args.batch_size,
        collate_fn=collate,
        seed=args.seed,
        shuffle=True,
        split_name="fold10",
        fold_idx=args.fold_idx,
    ).train_valid_loader()
    # or split_name='rand', split_ratio=0.7

    model = GIN(
        args.num_layers,
        args.num_mlp_layers,
        dataset.dim_nfeats,
        args.hidden_dim,
        dataset.gclasses,
        args.final_dropout,
        args.learn_eps,
        args.graph_pooling_type,
        args.neighbor_pooling_type,
    )
    model.initialize(ctx=args.device)

    criterion = gluon.loss.SoftmaxCELoss()

    print(model.collect_params())
    # halve the learning rate every 50 updates
    lr_scheduler = mx.lr_scheduler.FactorScheduler(50, 0.5)
    trainer = gluon.Trainer(
        model.collect_params(), "adam", {"lr_scheduler": lr_scheduler}
    )

    # Three stacked epoch-level progress bars (train / valid / eps) at fixed
    # terminal rows so they do not overwrite each other.
    # https://stackoverflow.com/a/23121189
    tbar = tqdm(
        range(args.epochs), unit="epoch", position=3, ncols=0, file=sys.stdout
    )
    vbar = tqdm(
        range(args.epochs), unit="epoch", position=4, ncols=0, file=sys.stdout
    )
    lrbar = tqdm(
        range(args.epochs), unit="epoch", position=5, ncols=0, file=sys.stdout
    )

    for epoch, _, _ in zip(tbar, vbar, lrbar):
        train(args, model, trainloader, trainer, criterion, epoch)

        train_loss, train_acc = eval_net(args, model, trainloader, criterion)
        tbar.set_description(
            "train set - average loss: {:.4f}, accuracy: {:.0f}%".format(
                train_loss, 100.0 * train_acc
            )
        )

        valid_loss, valid_acc = eval_net(args, model, validloader, criterion)
        vbar.set_description(
            "valid set - average loss: {:.4f}, accuracy: {:.0f}%".format(
                valid_loss, 100.0 * valid_acc
            )
        )

        # append per-epoch metrics to the results file when one was given
        if args.filename != "":
            with open(args.filename, "a") as f:
                f.write(
                    "%s %s %s %s"
                    % (
                        args.dataset,
                        args.learn_eps,
                        args.neighbor_pooling_type,
                        args.graph_pooling_type,
                    )
                )
                f.write("\n")
                f.write(
                    "%f %f %f %f"
                    % (train_loss, train_acc, valid_loss, valid_acc)
                )
                f.write("\n")

        lrbar.set_description(
            "Learning eps with learn_eps={}: {}".format(
                args.learn_eps,
                [
                    layer.eps.data(args.device).asscalar()
                    for layer in model.ginlayers
                ],
            )
        )

    tbar.close()
    vbar.close()
    lrbar.close()


if __name__ == "__main__":
    # Parse CLI arguments (project-local Parser wraps argparse) and echo
    # the full configuration before training starts.
    args = Parser(description="GIN").args
    print("show all arguments configuration...")
    print(args)

    main(args)