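"""Tree-LSTM sentiment classification on the Stanford Sentiment Treebank (SST),
using DGL with the MXNet/Gluon backend.

The script trains on the SST training split, reports node-level and root-level
accuracy on the dev split after every epoch, checkpoints the best dev epoch,
and finally evaluates that checkpoint on the test split.
"""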
import argparse
import collections
import os
import time
import warnings
import zipfile

os.environ["DGLBACKEND"] = "mxnet"
os.environ["MXNET_GPU_MEM_POOL_TYPE"] = "Round"

import dgl
import dgl.data as data
import mxnet as mx
import numpy as np
from mxnet import gluon
from tree_lstm import TreeLSTM

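# A collated minibatch of SST trees: the batched graph plus the per-node mask,
# word-id, and sentiment-label tensors.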
SSTBatch = collections.namedtuple(
    "SSTBatch", ["graph", "mask", "wordid", "label"]
)


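# Collate a list of tree graphs into one batched DGLGraph and move its node
# features onto the target device.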
def batcher(ctx):
    def batcher_dev(batch):
        batch_trees = dgl.batch(batch)
        return SSTBatch(
            graph=batch_trees,
            mask=batch_trees.ndata["mask"].as_in_context(ctx),
            wordid=batch_trees.ndata["x"].as_in_context(ctx),
            label=batch_trees.ndata["y"].as_in_context(ctx),
        )

    return batcher_dev


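# Download and extract the GloVe 840B 300d vectors unless a verified copy is
# already present in the working directory.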
def prepare_glove():
    if not (
        os.path.exists("glove.840B.300d.txt")
        and data.utils.check_sha1(
            "glove.840B.300d.txt",
            sha1_hash="294b9f37fa64cce31f9ebb409c266fc379527708",
        )
    ):
        zip_path = data.utils.download(
            "http://nlp.stanford.edu/data/glove.840B.300d.zip",
            sha1_hash="8084fbacc2dee3b1fd1ca4cc534cbfff3519ed0d",
        )
        with zipfile.ZipFile(zip_path, "r") as zf:
            zf.extractall()
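        # Re-verify the extracted file; warn (but continue) on a checksum mismatch.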
        if not data.utils.check_sha1(
            "glove.840B.300d.txt",
            sha1_hash="294b9f37fa64cce31f9ebb409c266fc379527708",
        ):
            warnings.warn(
                "Checksum mismatch for the downloaded GloVe embedding file; "
                "the file content may be corrupted."
            )


def main(args):
    np.random.seed(args.seed)
    mx.random.seed(args.seed)

    best_epoch = -1
    best_dev_acc = 0

    cuda = args.gpu >= 0
    if cuda:
        if args.gpu in mx.test_utils.list_gpus():
            ctx = mx.gpu(args.gpu)
        else:
            print(
                "Requested GPU id {} was not found. Defaulting to CPU implementation".format(
                    args.gpu
                )
            )
            ctx = mx.cpu()
    else:
        ctx = mx.cpu()

    if args.use_glove:
        prepare_glove()

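    # Each SST example is a constituency tree stored as a DGLGraph with node
    # features "x" (word id), "y" (sentiment label), and "mask".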
    trainset = data.SSTDataset()
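    # The loaders collate trees with batcher(ctx); dev and test use a fixed
    # batch size of 100.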
    train_loader = gluon.data.DataLoader(
        dataset=trainset,
        batch_size=args.batch_size,
        batchify_fn=batcher(ctx),
        shuffle=True,
        num_workers=0,
    )
    devset = data.SSTDataset(mode="dev")
    dev_loader = gluon.data.DataLoader(
        dataset=devset,
        batch_size=100,
        batchify_fn=batcher(ctx),
        shuffle=True,
        num_workers=0,
    )

    testset = data.SSTDataset(mode="test")
    test_loader = gluon.data.DataLoader(
        dataset=testset,
        batch_size=100,
        batchify_fn=batcher(ctx),
        shuffle=False,
        num_workers=0,
    )

    model = TreeLSTM(
        trainset.vocab_size,
        args.x_size,
        args.h_size,
        trainset.num_classes,
        args.dropout,
        cell_type="childsum" if args.child_sum else "nary",
        pretrained_emb=trainset.pretrained_emb,
        ctx=ctx,
    )
    print(model)
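    # Trainable parameters other than the embedding table (identified by a
    # first dimension equal to the vocabulary size).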
    params_ex_emb = [
        x
        for x in model.collect_params().values()
        if x.grad_req != "null" and x.shape[0] != trainset.vocab_size
    ]
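    # Give the pretrained embedding a 10x smaller effective learning rate.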
    params_emb = list(model.embedding.collect_params().values())
    for p in params_emb:
        p.lr_mult = 0.1

    model.initialize(mx.init.Xavier(magnitude=1), ctx=ctx)
    model.hybridize()
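    # Two Adagrad trainers: weight decay applies to everything except the
    # embedding, which is updated by its own trainer.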
    trainer = gluon.Trainer(
        model.collect_params("^(?!embedding).*$"),
        "adagrad",
        {"learning_rate": args.lr, "wd": args.weight_decay},
    )
    trainer_emb = gluon.Trainer(
        model.collect_params("^embedding.*$"),
        "adagrad",
        {"learning_rate": args.lr},
    )

    dur = []
    L = gluon.loss.SoftmaxCrossEntropyLoss(axis=1)
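    # Per-step timing skips the first three steps of each epoch as warm-up.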
    for epoch in range(args.epochs):
        t_epoch = time.time()
        for step, batch in enumerate(train_loader):
            g = batch.graph
            n = g.number_of_nodes()

            # TODO begin_states function?
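            # Zero initial hidden and cell states, one row per node in the batched graph.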
            h = mx.nd.zeros((n, args.h_size), ctx=ctx)
            c = mx.nd.zeros((n, args.h_size), ctx=ctx)
            if step >= 3:
                t0 = time.time()  # tik
            with mx.autograd.record():
                pred = model(batch, h, c)
                loss = L(pred, batch.label)

            loss.backward()
            trainer.step(args.batch_size)
            trainer_emb.step(args.batch_size)

            if step >= 3:
                dur.append(time.time() - t0)  # tok

            if step > 0 and step % args.log_every == 0:
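                # Node-level accuracy over all nodes; root accuracy over roots only,
                # i.e. nodes with out-degree 0 (edges point from child to parent).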
                pred = pred.argmax(axis=1).astype(batch.label.dtype)
                acc = (batch.label == pred).sum()
                root_ids = [
                    i
                    for i in range(batch.graph.number_of_nodes())
                    if batch.graph.out_degrees(i) == 0
                ]
                root_acc = np.sum(
                    batch.label.asnumpy()[root_ids] == pred.asnumpy()[root_ids]
                )

                print(
                    "Epoch {:05d} | Step {:05d} | Loss {:.4f} | Acc {:.4f} | Root Acc {:.4f} | Time(s) {:.4f}".format(
                        epoch,
                        step,
                        loss.sum().asscalar(),
                        1.0 * acc.asscalar() / len(batch.label),
                        1.0 * root_acc / len(root_ids),
                        np.mean(dur),
                    )
                )
        print(
            "Epoch {:05d} training time {:.4f}s".format(
                epoch, time.time() - t_epoch
            )
        )

        # eval on dev set
        accs = []
        root_accs = []
        for step, batch in enumerate(dev_loader):
            g = batch.graph
            n = g.number_of_nodes()
            h = mx.nd.zeros((n, args.h_size), ctx=ctx)
            c = mx.nd.zeros((n, args.h_size), ctx=ctx)
            pred = model(batch, h, c).argmax(1).astype(batch.label.dtype)

            acc = (batch.label == pred).sum().asscalar()
            accs.append([acc, len(batch.label)])
            root_ids = [
                i
                for i in range(batch.graph.number_of_nodes())
                if batch.graph.out_degrees(i) == 0
            ]
            root_acc = np.sum(
                batch.label.asnumpy()[root_ids] == pred.asnumpy()[root_ids]
            )
            root_accs.append([root_acc, len(root_ids)])

        dev_acc = (
            1.0 * np.sum([x[0] for x in accs]) / np.sum([x[1] for x in accs])
        )
        dev_root_acc = (
            1.0
            * np.sum([x[0] for x in root_accs])
            / np.sum([x[1] for x in root_accs])
        )
        print(
            "Epoch {:05d} | Dev Acc {:.4f} | Root Acc {:.4f}".format(
                epoch, dev_acc, dev_root_acc
            )
        )

        # Checkpoint on the best dev root accuracy; stop early if it has not
        # improved for 10 epochs.
        if dev_root_acc > best_dev_acc:
            best_dev_acc = dev_root_acc
            best_epoch = epoch
            model.save_parameters("best_{}.params".format(args.seed))
        elif best_epoch <= epoch - 10:
            break

        # Decay both learning rates by 1% per epoch, floored at 1e-5.
        trainer.set_learning_rate(max(1e-5, trainer.learning_rate * 0.99))
        print(trainer.learning_rate)
        trainer_emb.set_learning_rate(
            max(1e-5, trainer_emb.learning_rate * 0.99)
        )
        print(trainer_emb.learning_rate)

    # Evaluate on the test split with the parameters from the best dev epoch.
    model.load_parameters("best_{}.params".format(args.seed))
    accs = []
    root_accs = []
    for step, batch in enumerate(test_loader):
        g = batch.graph
        n = g.number_of_nodes()
        h = mx.nd.zeros((n, args.h_size), ctx=ctx)
        c = mx.nd.zeros((n, args.h_size), ctx=ctx)
        pred = model(batch, h, c).argmax(axis=1).astype(batch.label.dtype)

        acc = (batch.label == pred).sum().asscalar()
        accs.append([acc, len(batch.label)])
        root_ids = [
            i
            for i in range(batch.graph.number_of_nodes())
            if batch.graph.out_degrees(i) == 0
        ]
        root_acc = np.sum(
            batch.label.asnumpy()[root_ids] == pred.asnumpy()[root_ids]
        )
        root_accs.append([root_acc, len(root_ids)])

    test_acc = 1.0 * np.sum([x[0] for x in accs]) / np.sum([x[1] for x in accs])
    test_root_acc = (
        1.0
        * np.sum([x[0] for x in root_accs])
        / np.sum([x[1] for x in root_accs])
    )
    print(
        "------------------------------------------------------------------------------------"
    )
    print(
        "Epoch {:05d} | Test Acc {:.4f} | Root Acc {:.4f}".format(
            best_epoch, test_acc, test_root_acc
        )
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--gpu", type=int, default=0)
    parser.add_argument("--seed", type=int, default=41)
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--child-sum", action="store_true")
    parser.add_argument("--x-size", type=int, default=300)
    parser.add_argument("--h-size", type=int, default=150)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--log-every", type=int, default=5)
    parser.add_argument("--lr", type=float, default=0.05)
    parser.add_argument("--weight-decay", type=float, default=1e-4)
    parser.add_argument("--dropout", type=float, default=0.5)
    parser.add_argument("--use-glove", action="store_true")
    args = parser.parse_args()
    print(args)
    main(args)