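"""Train a Tree-LSTM sentiment classifier on the Stanford Sentiment Treebank.

Uses DGL with the MXNet/Gluon backend: trees are merged into batched graphs
with dgl.batch, and the model is evaluated with both all-node and root-node
accuracy.

Example invocation (using only the flags defined below):
    python train.py --gpu 0 --use-glove
"""
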
import argparse
import collections
import os
import time
import warnings
import zipfile

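# Use the MXNet backend for DGL and MXNet's "Round" GPU memory pool, which
# rounds allocation sizes up so buffers can be reused across variable-sized
# tree batches. Both must be set before dgl/mxnet are imported.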
os.environ["DGLBACKEND"] = "mxnet"
os.environ["MXNET_GPU_MEM_POOL_TYPE"] = "Round"

import mxnet as mx
import numpy as np
from mxnet import gluon
from tree_lstm import TreeLSTM

import dgl
import dgl.data as data

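# A minibatch of SST trees: the batched graph plus per-node mask, word-id and
# sentiment-label tensors.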
SSTBatch = collections.namedtuple(
    "SSTBatch", ["graph", "mask", "wordid", "label"]
)


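# batchify_fn for the Gluon DataLoaders below: merge a list of trees into a
# single graph and move the node features onto the target device.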
def batcher(ctx):
    def batcher_dev(batch):
        batch_trees = dgl.batch(batch)
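        # The batched graph keeps each tree as a separate connected component;
        # node features are concatenated into single tensors.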
        return SSTBatch(
            graph=batch_trees,
            mask=batch_trees.ndata["mask"].as_in_context(ctx),
            wordid=batch_trees.ndata["x"].as_in_context(ctx),
            label=batch_trees.ndata["y"].as_in_context(ctx),
        )

    return batcher_dev


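# Download and extract the GloVe 840B.300d embeddings unless a
# checksum-verified copy already exists in the working directory.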
def prepare_glove():
    if not (
        os.path.exists("glove.840B.300d.txt")
        and data.utils.check_sha1(
            "glove.840B.300d.txt",
            sha1_hash="294b9f37fa64cce31f9ebb409c266fc379527708",
        )
    ):
        zip_path = data.utils.download(
            "http://nlp.stanford.edu/data/glove.840B.300d.zip",
            sha1_hash="8084fbacc2dee3b1fd1ca4cc534cbfff3519ed0d",
        )
        with zipfile.ZipFile(zip_path, "r") as zf:
            zf.extractall()
        if not data.utils.check_sha1(
            "glove.840B.300d.txt",
            sha1_hash="294b9f37fa64cce31f9ebb409c266fc379527708",
        ):
            warnings.warn(
                "The downloaded glove embedding file checksum mismatch. File content "
                "may be corrupted."
            )


def main(args):
    np.random.seed(args.seed)
    mx.random.seed(args.seed)

    best_epoch = -1
    best_dev_acc = 0

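    # Pick the device, falling back to CPU if the requested GPU is unavailable.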
    cuda = args.gpu >= 0
    if cuda:
        if args.gpu in mx.test_utils.list_gpus():
            ctx = mx.gpu(args.gpu)
        else:
            print(
                "Requested GPU id {} was not found. Defaulting to CPU implementation".format(
                    args.gpu
                )
            )
            ctx = mx.cpu()
    else:
        ctx = mx.cpu()

    if args.use_glove:
        prepare_glove()

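    # SST provides constituency trees with a sentiment label on every node;
    # batcher(ctx) turns each list of trees into one batched graph.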
    trainset = data.SSTDataset()
    train_loader = gluon.data.DataLoader(
        dataset=trainset,
        batch_size=args.batch_size,
        batchify_fn=batcher(ctx),
        shuffle=True,
        num_workers=0,
    )
    devset = data.SSTDataset(mode="dev")
    dev_loader = gluon.data.DataLoader(
        dataset=devset,
        batch_size=100,
        batchify_fn=batcher(ctx),
        shuffle=True,
        num_workers=0,
    )

    testset = data.SSTDataset(mode="test")
    test_loader = gluon.data.DataLoader(
        dataset=testset,
        batch_size=100,
        batchify_fn=batcher(ctx),
        shuffle=False,
        num_workers=0,
    )

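    # Build the Tree-LSTM: the n-ary (binary) cell by default, or the
    # child-sum variant with --child-sum; word embeddings are initialised from
    # the dataset's pretrained GloVe vectors.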
    model = TreeLSTM(
        trainset.vocab_size,
        args.x_size,
        args.h_size,
        trainset.num_classes,
        args.dropout,
        cell_type="childsum" if args.child_sum else "nary",
        pretrained_emb=trainset.pretrained_emb,
        ctx=ctx,
    )
    print(model)
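
    # All trainable parameters except the word-embedding table.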
    params_ex_emb = [
        x
        for x in model.collect_params().values()
        if x.grad_req != "null" and x.shape[0] != trainset.vocab_size
    ]
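    # Train the embedding table with a 10x smaller learning rate.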
    params_emb = list(model.embedding.collect_params().values())
    for p in params_emb:
        p.lr_mult = 0.1

    model.initialize(mx.init.Xavier(magnitude=1), ctx=ctx)
    model.hybridize()
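
    # Two Adagrad trainers: weight decay applies to everything except the
    # embedding table, which is updated without decay and at the reduced
    # learning rate set via lr_mult above.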
    trainer = gluon.Trainer(
        model.collect_params("^(?!embedding).*$"),
        "adagrad",
        {"learning_rate": args.lr, "wd": args.weight_decay},
    )
    trainer_emb = gluon.Trainer(
        model.collect_params("^embedding.*$"),
        "adagrad",
        {"learning_rate": args.lr},
    )

    dur = []
    L = gluon.loss.SoftmaxCrossEntropyLoss(axis=1)
    for epoch in range(args.epochs):
        t_epoch = time.time()
        for step, batch in enumerate(train_loader):
            g = batch.graph
            n = g.number_of_nodes()

            # TODO begin_states function?
            h = mx.nd.zeros((n, args.h_size), ctx=ctx)
            c = mx.nd.zeros((n, args.h_size), ctx=ctx)
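            # Time steps only after a short warm-up so the average is stable.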
            if step >= 3:
                t0 = time.time()  # tik
            with mx.autograd.record():
                pred = model(batch, h, c)
                loss = L(pred, batch.label)

            loss.backward()
            trainer.step(args.batch_size)
            trainer_emb.step(args.batch_size)

            if step >= 3:
                dur.append(time.time() - t0)  # tok

            if step > 0 and step % args.log_every == 0:
                pred = pred.argmax(axis=1).astype(batch.label.dtype)
                acc = (batch.label == pred).sum()
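                # Tree roots have no outgoing edges; root accuracy is the
                # sentence-level SST metric.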
                root_ids = [
                    i
                    for i in range(batch.graph.number_of_nodes())
                    if batch.graph.out_degree(i) == 0
                ]
                root_acc = np.sum(
                    batch.label.asnumpy()[root_ids] == pred.asnumpy()[root_ids]
                )

                print(
                    "Epoch {:05d} | Step {:05d} | Loss {:.4f} | Acc {:.4f} | Root Acc {:.4f} | Time(s) {:.4f}".format(
                        epoch,
                        step,
                        loss.sum().asscalar(),
                        1.0 * acc.asscalar() / len(batch.label),
                        1.0 * root_acc / len(root_ids),
                        np.mean(dur),
                    )
                )
        print(
            "Epoch {:05d} training time {:.4f}s".format(
                epoch, time.time() - t_epoch
            )
        )

        # eval on dev set
        accs = []
        root_accs = []
        for step, batch in enumerate(dev_loader):
            g = batch.graph
            n = g.number_of_nodes()
            h = mx.nd.zeros((n, args.h_size), ctx=ctx)
            c = mx.nd.zeros((n, args.h_size), ctx=ctx)
            pred = model(batch, h, c).argmax(axis=1).astype(batch.label.dtype)

            acc = (batch.label == pred).sum().asscalar()
            accs.append([acc, len(batch.label)])
            root_ids = [
                i
                for i in range(batch.graph.number_of_nodes())
                if batch.graph.out_degree(i) == 0
            ]
            root_acc = np.sum(
                batch.label.asnumpy()[root_ids] == pred.asnumpy()[root_ids]
            )
            root_accs.append([root_acc, len(root_ids)])

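        # Aggregate (correct, total) pairs over all dev batches.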
        dev_acc = (
            1.0 * np.sum([x[0] for x in accs]) / np.sum([x[1] for x in accs])
        )
        dev_root_acc = (
            1.0
            * np.sum([x[0] for x in root_accs])
            / np.sum([x[1] for x in root_accs])
        )
        print(
            "Epoch {:05d} | Dev Acc {:.4f} | Root Acc {:.4f}".format(
                epoch, dev_acc, dev_root_acc
            )
        )

        if dev_root_acc > best_dev_acc:
            best_dev_acc = dev_root_acc
            best_epoch = epoch
            model.save_parameters("best_{}.params".format(args.seed))
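        # Early stopping: quit if dev root accuracy has not improved in the
        # last 10 epochs.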
        elif best_epoch <= epoch - 10:
            break

        # Learning-rate decay: shrink both rates by 1% per epoch, floored at 1e-5.
        trainer.set_learning_rate(max(1e-5, trainer.learning_rate * 0.99))
        print(trainer.learning_rate)
        trainer_emb.set_learning_rate(
            max(1e-5, trainer_emb.learning_rate * 0.99)
        )
        print(trainer_emb.learning_rate)

    # Evaluate the best checkpoint (chosen by dev root accuracy) on the test set.
    model.load_parameters("best_{}.params".format(args.seed))
    accs = []
    root_accs = []
    for step, batch in enumerate(test_loader):
        g = batch.graph
        n = g.number_of_nodes()
        h = mx.nd.zeros((n, args.h_size), ctx=ctx)
        c = mx.nd.zeros((n, args.h_size), ctx=ctx)
        pred = model(batch, h, c).argmax(axis=1).astype(batch.label.dtype)

        acc = (batch.label == pred).sum().asscalar()
        accs.append([acc, len(batch.label)])
        root_ids = [
            i
            for i in range(batch.graph.number_of_nodes())
            if batch.graph.out_degree(i) == 0
        ]
        root_acc = np.sum(
            batch.label.asnumpy()[root_ids] == pred.asnumpy()[root_ids]
        )
        root_accs.append([root_acc, len(root_ids)])

    test_acc = 1.0 * np.sum([x[0] for x in accs]) / np.sum([x[1] for x in accs])
    test_root_acc = (
        1.0
        * np.sum([x[0] for x in root_accs])
        / np.sum([x[1] for x in root_accs])
    )
    print(
        "------------------------------------------------------------------------------------"
    )
    print(
        "Epoch {:05d} | Test Acc {:.4f} | Root Acc {:.4f}".format(
            best_epoch, test_acc, test_root_acc
        )
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--gpu", type=int, default=0)
    parser.add_argument("--seed", type=int, default=41)
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--child-sum", action="store_true")
    parser.add_argument("--x-size", type=int, default=300)
    parser.add_argument("--h-size", type=int, default=150)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--log-every", type=int, default=5)
    parser.add_argument("--lr", type=float, default=0.05)
    parser.add_argument("--weight-decay", type=float, default=1e-4)
    parser.add_argument("--dropout", type=float, default=0.5)
    parser.add_argument("--use-glove", action="store_true")
    args = parser.parse_args()
    print(args)
    main(args)