# train.py -- trains a TreeLSTM model on the SST dataset using DGL and PyTorch.
import argparse
import time
import numpy as np
import torch as th
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

import dgl
import dgl.data as data
import dgl.ndarray as nd

from tree_lstm import TreeLSTM

def main(args):
    """Train a TreeLSTM on the SST training set and print running statistics.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line options. The attributes read here are ``gpu``,
        ``batch_size``, ``x_size``, ``h_size``, ``dropout``, ``lr``,
        ``weight_decay``, ``epochs`` and ``log_every``.

    Notes
    -----
    Nothing is returned; progress is reported with ``print``. Every
    ``log_every`` steps (except step 0) the loss, accuracy over the nodes of
    the current batch, mean step time and an estimated trees-per-second
    figure are printed, followed by the wall-clock time of each epoch.
    """
    # A non-negative --gpu selects that CUDA device; a negative value means CPU.
    cuda = args.gpu >= 0
    if cuda:
        th.cuda.set_device(args.gpu)

    def _batcher(trees):
        """Merge a list of trees into one batched graph, moved to GPU if enabled."""
        bg = dgl.batch(trees)
        if cuda:
            # Move every node feature column to the current CUDA device.
            for key in bg.node_attr_schemes():
                bg.ndata[key] = bg.ndata[key].cuda()
        return bg

    trainset = data.SST()
    train_loader = DataLoader(dataset=trainset,
                              batch_size=args.batch_size,
                              collate_fn=_batcher,
                              shuffle=False,
                              num_workers=0)
    # Evaluation loader is currently disabled; kept for reference.
    #testset = data.SST(mode='test')
    #test_loader = DataLoader(dataset=testset,
    #                         batch_size=100,
    #                         collate_fn=data.SST.batcher,
    #                         shuffle=False,
    #                         num_workers=0)

    model = TreeLSTM(trainset.num_vocabs,
                     args.x_size,
                     args.h_size,
                     trainset.num_classes,
                     args.dropout)
    if cuda:
        model.cuda()

        def zero_initializer(shape):
            """Return a zero tensor of ``shape`` on the current CUDA device."""
            # Allocate directly on the GPU instead of on CPU followed by a copy.
            return th.zeros(shape, device='cuda')
    else:
        zero_initializer = th.zeros

    print(model)
    optimizer = optim.Adagrad(model.parameters(),
                              lr=args.lr,
                              weight_decay=args.weight_decay)

    # Durations of the timed training steps, accumulated across all epochs.
    dur = []
    for epoch in range(args.epochs):
        t_epoch = time.time()
        for step, graph in enumerate(train_loader):
            # Steps 0-2 of every epoch are excluded from timing
            # (presumably as warm-up -- TODO confirm).
            # NOTE(review): on GPU the interval is measured without an explicit
            # th.cuda.synchronize(), so it may not cover all queued kernels.
            timed = step >= 3
            if timed:
                t0 = time.time()

            # The labels are stored as node feature 'y'; remove them from the
            # graph before handing it to the model.
            label = graph.ndata.pop('y')
            # traverse graph
            logits = model(graph, zero_initializer, train=True)
            logp = F.log_softmax(logits, 1)
            loss = F.nll_loss(logp, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if timed:
                dur.append(time.time() - t0)

            if step > 0 and step % args.log_every == 0:
                pred = th.argmax(logits, 1)
                acc = th.sum(th.eq(label, pred))
                if dur:
                    mean_dur = np.mean(dur)
                    trees_per_sec = args.batch_size / mean_dur
                else:
                    # No step has been timed yet (possible when log_every < 3);
                    # report NaN without triggering numpy's empty-slice warning.
                    mean_dur = trees_per_sec = float('nan')
                print("Epoch {:05d} | Step {:05d} | Loss {:.4f} | "
                      "Acc {:.4f} | Time(s) {:.4f} | Trees/s {:.4f}".format(
                    epoch, step, loss.item(), acc.item() / len(label),
                    mean_dur, trees_per_sec))
        print("Epoch time(s):", time.time() - t_epoch)

        # test
        #for step, batch in enumerate(test_loader):
        #    g = batch.graph
        #    n = g.number_of_nodes()
        #    x = th.zeros((n, args.x_size))
        #    h = th.zeros((n, args.h_size))
        #    c = th.zeros((n, args.h_size))
        #    logits = model(batch, x, h, c, train=True)
        #    pred = th.argmax(logits, 1)
        #    acc = th.sum(th.eq(batch.label, pred)) / len(batch.label)
        #    print(acc.item())

def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Parameters
    ----------
    argv : list of str, optional
        Argument strings to parse. With the default ``None`` argparse reads
        them from ``sys.argv[1:]``.

    Returns
    -------
    argparse.Namespace
        Namespace with attributes ``gpu``, ``batch_size``, ``x_size``,
        ``h_size``, ``epochs``, ``log_every``, ``lr``, ``n_ary``,
        ``weight_decay`` and ``dropout``.
    """
    parser = argparse.ArgumentParser(
        description='Train a TreeLSTM on the SST dataset.')
    parser.add_argument('--gpu', type=int, default=-1,
                        help='CUDA device index; a negative value runs on CPU')
    parser.add_argument('--batch-size', type=int, default=25,
                        help='number of trees per training batch')
    parser.add_argument('--x-size', type=int, default=256,
                        help='x_size argument passed to TreeLSTM')
    parser.add_argument('--h-size', type=int, default=256,
                        help='h_size argument passed to TreeLSTM')
    parser.add_argument('--epochs', type=int, default=100,
                        help='number of passes over the training set')
    parser.add_argument('--log-every', type=int, default=5,
                        help='print training statistics every N steps')
    parser.add_argument('--lr', type=float, default=0.05,
                        help='learning rate of the Adagrad optimizer')
    # Accepted for command-line compatibility; main() does not read it.
    parser.add_argument('--n-ary', type=int, default=2,
                        help='accepted but currently not read by this script')
    parser.add_argument('--weight-decay', type=float, default=1e-4,
                        help='weight decay of the Adagrad optimizer')
    parser.add_argument('--dropout', type=float, default=0.5,
                        help='dropout argument passed to TreeLSTM')
    return parser.parse_args(argv)


if __name__ == '__main__':
    main(parse_args())