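"""Train a Tree-LSTM (cf. Tai et al., 2015) for sentiment classification on
the Stanford Sentiment Treebank.

Example invocation (all flags are defined by the argument parser below):

    python train.py --gpu 0 --batch-size 25 --epochs 100
"""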
import argparse
import collections
import time
import numpy as np
import torch as th
import torch.nn.functional as F
import torch.nn.init as INIT
import torch.optim as optim
from torch.utils.data import DataLoader

import dgl
from dgl.data.tree import SSTDataset

from tree_lstm import TreeLSTM

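# A minibatch container: the batched graph plus per-node tensors (mask,
# word ids, labels) taken from the graph's ndata and moved to the device.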
SSTBatch = collections.namedtuple('SSTBatch', ['graph', 'mask', 'wordid', 'label'])

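# `batcher` builds a collate_fn for DataLoader: dgl.batch merges a list of
# tree graphs into a single graph whose connected components are the
# individual trees, so one message-passing pass handles the whole batch.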
def batcher(device):
    def batcher_dev(batch):
        batch_trees = dgl.batch(batch)
        return SSTBatch(graph=batch_trees,
                        mask=batch_trees.ndata['mask'].to(device),
                        wordid=batch_trees.ndata['x'].to(device),
                        label=batch_trees.ndata['y'].to(device))
    return batcher_dev

def main(args):
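    # Seed numpy and torch (CPU and CUDA) RNGs for reproducibility.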
    np.random.seed(args.seed)
    th.manual_seed(args.seed)
    th.cuda.manual_seed(args.seed)

    best_epoch = -1
    best_dev_acc = 0

    cuda = args.gpu >= 0
    device = th.device('cuda:{}'.format(args.gpu)) if cuda else th.device('cpu')
    if cuda:
        th.cuda.set_device(args.gpu)

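    # SST parse trees carry a sentiment label on every node, not only the
    # root. num_workers is kept at 0 because the collate_fn already moves
    # tensors to the (possibly CUDA) device.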
    trainset = SSTDataset()
    train_loader = DataLoader(dataset=trainset,
                              batch_size=args.batch_size,
                              collate_fn=batcher(device),
                              shuffle=True,
                              num_workers=0)
    devset = SSTDataset(mode='dev')
    dev_loader = DataLoader(dataset=devset,
                            batch_size=100,
                            collate_fn=batcher(device),
                            shuffle=False,
                            num_workers=0)

    testset = SSTDataset(mode='test')
    test_loader = DataLoader(dataset=testset,
                             batch_size=100,
                             collate_fn=batcher(device),
                             shuffle=False,
                             num_workers=0)

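    # Build the Tree-LSTM: the child-sum cell handles arbitrary branching,
    # while the default N-ary cell assumes binarized constituency trees.
    # Word embeddings are initialized from the dataset's pretrained vectors.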
    model = TreeLSTM(trainset.vocab_size,
                     args.x_size,
                     args.h_size,
                     trainset.num_classes,
                     args.dropout,
                     cell_type='childsum' if args.child_sum else 'nary',
                     pretrained_emb=trainset.pretrained_emb).to(device)
    print(model)
    # All trainable parameters except the embedding table (identified by a
    # first dimension equal to the vocabulary size).
    params_ex_emb = [x for x in list(model.parameters())
                     if x.requires_grad and x.size(0) != trainset.vocab_size]
    params_emb = list(model.embedding.parameters())

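    # Xavier-initialize weight matrices; 1-D parameters (biases) keep their defaults.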
    for p in params_ex_emb:
        if p.dim() > 1:
            INIT.xavier_uniform_(p)

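    # Adagrad with two parameter groups: embeddings are fine-tuned at one
    # tenth of the base learning rate and without weight decay.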
    optimizer = optim.Adagrad([
        {'params': params_ex_emb, 'lr': args.lr, 'weight_decay': args.weight_decay},
        {'params': params_emb, 'lr': 0.1 * args.lr}])

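    # Training loop. Steps 0-2 of each epoch are treated as warm-up and
    # excluded from the running per-step timing average.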
    dur = []
    for epoch in range(args.epochs):
        t_epoch = time.time()
        model.train()
        for step, batch in enumerate(train_loader):
            g = batch.graph.to(device)
            n = g.number_of_nodes()
            # Zero initial hidden and cell states, one row per node of the batched graph.
            h = th.zeros((n, args.h_size)).to(device)
            c = th.zeros((n, args.h_size)).to(device)
            if step >= 3:
                t0 = time.time()  # start the per-step timer (after warm-up)

            logits = model(batch, g, h, c)
            logp = F.log_softmax(logits, 1)
            # Every node has a sentiment label, so the loss sums over all
            # nodes in the batch, not just sentence roots.
            loss = F.nll_loss(logp, batch.label, reduction='sum')

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step >= 3:
                dur.append(time.time() - t0)  # stop the per-step timer

            if step > 0 and step % args.log_every == 0:
                pred = th.argmax(logits, 1)
                acc = th.sum(th.eq(batch.label, pred))
                # Roots are the nodes with no outgoing edges (tree edges
                # point from child to parent).
                root_ids = [i for i in range(g.number_of_nodes()) if g.out_degree(i) == 0]
                root_acc = np.sum(batch.label.cpu().numpy()[root_ids] == pred.cpu().numpy()[root_ids])

                print("Epoch {:05d} | Step {:05d} | Loss {:.4f} | Acc {:.4f} | Root Acc {:.4f} | Time(s) {:.4f}".format(
                    epoch, step, loss.item(), 1.0*acc.item()/len(batch.label), 1.0*root_acc/len(root_ids), np.mean(dur)))
        print('Epoch {:05d} training time {:.4f}s'.format(epoch, time.time() - t_epoch))

        # Evaluate on the dev split after each epoch; model selection uses
        # root (sentence-level) accuracy.
        accs = []
        root_accs = []
        model.eval()
        for step, batch in enumerate(dev_loader):
            g = batch.graph.to(device)
            n = g.number_of_nodes()
            with th.no_grad():
                h = th.zeros((n, args.h_size)).to(device)
                c = th.zeros((n, args.h_size)).to(device)
                logits = model(batch, g, h, c)

            pred = th.argmax(logits, 1)
            acc = th.sum(th.eq(batch.label, pred)).item()
            accs.append([acc, len(batch.label)])
            root_ids = [i for i in range(g.number_of_nodes()) if g.out_degree(i) == 0]
            root_acc = np.sum(batch.label.cpu().numpy()[root_ids] == pred.cpu().numpy()[root_ids])
            root_accs.append([root_acc, len(root_ids)])

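        # Reduce per-batch (correct, total) counts into overall accuracies.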
        dev_acc = 1.0*np.sum([x[0] for x in accs])/np.sum([x[1] for x in accs])
        dev_root_acc = 1.0*np.sum([x[0] for x in root_accs])/np.sum([x[1] for x in root_accs])
        print("Epoch {:05d} | Dev Acc {:.4f} | Root Acc {:.4f}".format(
            epoch, dev_acc, dev_root_acc))

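        # Keep the checkpoint with the best dev root accuracy; stop early if
        # it has not improved for 10 consecutive epochs.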
        if dev_root_acc > best_dev_acc:
            best_dev_acc = dev_root_acc
            best_epoch = epoch
            th.save(model.state_dict(), 'best_{}.pkl'.format(args.seed))
        else:
            if best_epoch <= epoch - 10:
                break

        # Learning-rate decay: multiply every group's lr by 0.99 per epoch,
        # floored at 1e-5.
        for param_group in optimizer.param_groups:
            param_group['lr'] = max(1e-5, param_group['lr'] * 0.99)
            print(param_group['lr'])

    # Final test evaluation: reload the checkpoint with the best dev root accuracy.
    model.load_state_dict(th.load('best_{}.pkl'.format(args.seed)))
    accs = []
    root_accs = []
    model.eval()
    for step, batch in enumerate(test_loader):
        g = batch.graph.to(device)
        n = g.number_of_nodes()
        with th.no_grad():
            h = th.zeros((n, args.h_size)).to(device)
            c = th.zeros((n, args.h_size)).to(device)
            logits = model(batch, g, h, c)

        pred = th.argmax(logits, 1)
        acc = th.sum(th.eq(batch.label, pred)).item()
        accs.append([acc, len(batch.label)])
        root_ids = [i for i in range(g.number_of_nodes()) if g.out_degree(i) == 0]
        root_acc = np.sum(batch.label.cpu().numpy()[root_ids] == pred.cpu().numpy()[root_ids])
        root_accs.append([root_acc, len(root_ids)])

    test_acc = 1.0*np.sum([x[0] for x in accs])/np.sum([x[1] for x in accs])
    test_root_acc = 1.0*np.sum([x[0] for x in root_accs])/np.sum([x[1] for x in root_accs])
    print('------------------------------------------------------------------------------------')
    print("Epoch {:05d} | Test Acc {:.4f} | Root Acc {:.4f}".format(
        best_epoch, test_acc, test_root_acc))

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=int, default=-1)
    parser.add_argument('--seed', type=int, default=41)
    parser.add_argument('--batch-size', type=int, default=20)
    parser.add_argument('--child-sum', action='store_true')
    parser.add_argument('--x-size', type=int, default=300)
    parser.add_argument('--h-size', type=int, default=150)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--log-every', type=int, default=5)
    parser.add_argument('--lr', type=float, default=0.05)
    parser.add_argument('--weight-decay', type=float, default=1e-4)
    parser.add_argument('--dropout', type=float, default=0.5)
    args = parser.parse_args()
    print(args)
    main(args)