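"""Train a Tree-LSTM (Tai et al., 2015) for sentiment classification on
the Stanford Sentiment Treebank (SST), using DGL to batch the trees."""
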
import argparse
import collections
import time

import dgl

import numpy as np
import torch as th
import torch.nn.functional as F
import torch.nn.init as INIT
import torch.optim as optim
from dgl.data.tree import SSTDataset
from torch.utils.data import DataLoader
from tree_lstm import TreeLSTM

SSTBatch = collections.namedtuple(
    "SSTBatch", ["graph", "mask", "wordid", "label"]
)


def batcher(device):
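    """Return a collate function that merges a list of tree graphs into
    one batched graph and moves its node features onto `device`."""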
    def batcher_dev(batch):
        batch_trees = dgl.batch(batch)
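        # Node features from SSTDataset: word ids ("x"), sentiment
        # labels ("y"), and a mask selecting the nodes that carry a word.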
        return SSTBatch(
            graph=batch_trees,
            mask=batch_trees.ndata["mask"].to(device),
            wordid=batch_trees.ndata["x"].to(device),
            label=batch_trees.ndata["y"].to(device),
        )

    return batcher_dev


def main(args):
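    # Fix RNG seeds for reproducible runs.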
    np.random.seed(args.seed)
    th.manual_seed(args.seed)
    th.cuda.manual_seed(args.seed)

    best_epoch = -1
    best_dev_acc = 0

    cuda = args.gpu >= 0
    device = th.device("cuda:{}".format(args.gpu)) if cuda else th.device("cpu")
    if cuda:
        th.cuda.set_device(args.gpu)

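    # SST splits; each sample is a sentiment tree stored as a DGLGraph.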
    trainset = SSTDataset()
    train_loader = DataLoader(
        dataset=trainset,
        batch_size=args.batch_size,
        collate_fn=batcher(device),
        shuffle=True,
        num_workers=0,
    )
    devset = SSTDataset(mode="dev")
    dev_loader = DataLoader(
        dataset=devset,
        batch_size=100,
        collate_fn=batcher(device),
        shuffle=False,
        num_workers=0,
    )

    testset = SSTDataset(mode="test")
    test_loader = DataLoader(
        dataset=testset,
        batch_size=100,
        collate_fn=batcher(device),
        shuffle=False,
        num_workers=0,
    )

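    # The child-sum cell handles arbitrary branching factors; the default
    # N-ary cell matches SST's binarized constituency trees.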
    model = TreeLSTM(
        trainset.vocab_size,
        args.x_size,
        args.h_size,
        trainset.num_classes,
        args.dropout,
        cell_type="childsum" if args.child_sum else "nary",
        pretrained_emb=trainset.pretrained_emb,
    ).to(device)
    print(model)
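    # All trainable parameters except the embedding matrix, identified by
    # its first dimension matching the vocabulary size.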
    params_ex_emb = [
        x
        for x in list(model.parameters())
        if x.requires_grad and x.size(0) != trainset.vocab_size
    ]
    params_emb = list(model.embedding.parameters())

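    # Xavier-initialize every non-embedding weight matrix.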
    for p in params_ex_emb:
        if p.dim() > 1:
            INIT.xavier_uniform_(p)

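    # Adagrad with weight decay on the model weights; the pretrained
    # embeddings are fine-tuned at one tenth of the base learning rate.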
    optimizer = optim.Adagrad(
        [
            {
                "params": params_ex_emb,
                "lr": args.lr,
                "weight_decay": args.weight_decay,
            },
            {"params": params_emb, "lr": 0.1 * args.lr},
        ]
    )

    dur = []
    for epoch in range(args.epochs):
        t_epoch = time.time()
        model.train()
        for step, batch in enumerate(train_loader):
            g = batch.graph.to(device)
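            # Zero initial hidden and memory-cell states for every node
            # in the batched graph.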
            n = g.number_of_nodes()
            h = th.zeros((n, args.h_size)).to(device)
            c = th.zeros((n, args.h_size)).to(device)
            if step >= 3:
                t0 = time.time()  # start step timer (first steps are warm-up)

            logits = model(batch, g, h, c)
            logp = F.log_softmax(logits, 1)
            loss = F.nll_loss(logp, batch.label, reduction="sum")

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step >= 3:
                dur.append(time.time() - t0)  # record step duration

            if step > 0 and step % args.log_every == 0:
                pred = th.argmax(logits, 1)
                acc = th.sum(th.eq(batch.label, pred))
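                # Edges point child -> parent, so roots have out-degree 0.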
                root_ids = [
                    i
                    for i in range(g.number_of_nodes())
                    if g.out_degrees(i) == 0
                ]
                root_acc = np.sum(
                    batch.label.cpu().data.numpy()[root_ids]
                    == pred.cpu().data.numpy()[root_ids]
                )

                print(
                    "Epoch {:05d} | Step {:05d} | Loss {:.4f} | Acc {:.4f} | Root Acc {:.4f} | Time(s) {:.4f}".format(
                        epoch,
                        step,
                        loss.item(),
                        1.0 * acc.item() / len(batch.label),
                        1.0 * root_acc / len(root_ids),
                        np.mean(dur),
                    )
                )
        print(
            "Epoch {:05d} training time {:.4f}s".format(
                epoch, time.time() - t_epoch
            )
        )

        # eval on dev set
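        # Evaluate on the dev set after every epoch.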
        accs = []
        root_accs = []
        model.eval()
        for step, batch in enumerate(dev_loader):
            g = batch.graph.to(device)
            n = g.number_of_nodes()
            with th.no_grad():
                h = th.zeros((n, args.h_size)).to(device)
                c = th.zeros((n, args.h_size)).to(device)
                logits = model(batch, g, h, c)

            pred = th.argmax(logits, 1)
            acc = th.sum(th.eq(batch.label, pred)).item()
            accs.append([acc, len(batch.label)])
            root_ids = [
                i for i in range(g.number_of_nodes()) if g.out_degrees(i) == 0
            ]
            root_acc = np.sum(
                batch.label.cpu().data.numpy()[root_ids]
                == pred.cpu().data.numpy()[root_ids]
            )
            root_accs.append([root_acc, len(root_ids)])

        dev_acc = (
            1.0 * np.sum([x[0] for x in accs]) / np.sum([x[1] for x in accs])
        )
        dev_root_acc = (
            1.0
            * np.sum([x[0] for x in root_accs])
            / np.sum([x[1] for x in root_accs])
        )
        print(
            "Epoch {:05d} | Dev Acc {:.4f} | Root Acc {:.4f}".format(
                epoch, dev_acc, dev_root_acc
            )
        )

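        # Checkpoint on best dev root accuracy; stop early after 10
        # epochs without improvement.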
        if dev_root_acc > best_dev_acc:
            best_dev_acc = dev_root_acc
            best_epoch = epoch
            th.save(model.state_dict(), "best_{}.pkl".format(args.seed))
        elif best_epoch <= epoch - 10:
            break

        # Decay each group's learning rate by 1% per epoch, floored at 1e-5.
        for param_group in optimizer.param_groups:
            param_group["lr"] = max(1e-5, param_group["lr"] * 0.99)
            print(param_group["lr"])

    # test
    model.load_state_dict(th.load("best_{}.pkl".format(args.seed)))
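    # Final test-set evaluation with the best dev checkpoint.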
    accs = []
    root_accs = []
    model.eval()
    for step, batch in enumerate(test_loader):
        g = batch.graph.to(device)
        n = g.number_of_nodes()
        with th.no_grad():
            h = th.zeros((n, args.h_size)).to(device)
            c = th.zeros((n, args.h_size)).to(device)
            logits = model(batch, g, h, c)

        pred = th.argmax(logits, 1)
        acc = th.sum(th.eq(batch.label, pred)).item()
        accs.append([acc, len(batch.label)])
        root_ids = [
            i for i in range(g.number_of_nodes()) if g.out_degrees(i) == 0
        ]
        root_acc = np.sum(
            batch.label.cpu().data.numpy()[root_ids]
            == pred.cpu().data.numpy()[root_ids]
        )
        root_accs.append([root_acc, len(root_ids)])

    test_acc = 1.0 * np.sum([x[0] for x in accs]) / np.sum([x[1] for x in accs])
    test_root_acc = (
        1.0
        * np.sum([x[0] for x in root_accs])
        / np.sum([x[1] for x in root_accs])
    )
    print(
        "------------------------------------------------------------------------------------"
    )
    print(
        "Epoch {:05d} | Test Acc {:.4f} | Root Acc {:.4f}".format(
            best_epoch, test_acc, test_root_acc
        )
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--gpu", type=int, default=-1)
    parser.add_argument("--seed", type=int, default=41)
    parser.add_argument("--batch-size", type=int, default=20)
    parser.add_argument("--child-sum", action="store_true")
    parser.add_argument("--x-size", type=int, default=300)
    parser.add_argument("--h-size", type=int, default=150)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--log-every", type=int, default=5)
    parser.add_argument("--lr", type=float, default=0.05)
    parser.add_argument("--weight-decay", type=float, default=1e-4)
    parser.add_argument("--dropout", type=float, default=0.5)
    args = parser.parse_args()
    print(args)
    main(args)