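"""Train a Tree-LSTM sentiment classifier on the Stanford Sentiment
Treebank (SST) with DGL.

A typical invocation, using the flags defined at the bottom of this file:

    python train.py --gpu 0 --batch-size 25 --epochs 100
"""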
import argparse
import time
import numpy as np
import torch as th
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

import dgl
import dgl.data as data
import dgl.context as ctx

from tree_lstm import TreeLSTM

def _batch_to_cuda(batch):
    """Move the tensors of an SST batch to the current GPU device."""
    return data.SSTBatch(graph=batch.graph,
                         nid_with_word=batch.nid_with_word.cuda(),
                         wordid=batch.wordid.cuda(),
                         label=batch.label.cuda())

def tensor_topo_traverse(g, cuda, args):
    """Yield frontiers of node ids in topological order.

    A node enters the frontier once every node it depends on (per the
    cached adjacency matrix) has already been emitted; for these trees
    this walks from the leaves up to the root. Detection is done by
    repeatedly recomputing dependency counts with a sparse matmul.
    """
    n = g.number_of_nodes()
    if cuda:
        adjmat = g.cached_graph.adjmat().get(ctx.gpu(args.gpu))
        mask = th.ones((n, 1)).cuda()
    else:
        adjmat = g.cached_graph.adjmat().get(ctx.cpu())
        mask = th.ones((n, 1))
    # degree[i] = number of still-unvisited nodes that node i depends on
    degree = th.spmm(adjmat, mask)
    while th.sum(mask) != 0.:
        # unvisited nodes whose remaining degree is zero form the frontier
        v = (degree == 0.).float()
        v = v * mask
        mask = mask - v
        frontier = th.squeeze(th.squeeze(v).nonzero(), 1)
        yield frontier
        # emitting the frontier decrements the degree of its dependents
        degree -= th.spmm(adjmat, v)

def main(args):
    cuda = args.gpu >= 0
    if cuda:
        th.cuda.set_device(args.gpu)
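    # data.SST is the Stanford Sentiment Treebank; SST.batcher collates the
    # sampled trees into a single SSTBatch per minibatch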
    trainset = data.SST()
    train_loader = DataLoader(dataset=trainset,
                              batch_size=args.batch_size,
                              collate_fn=data.SST.batcher,
                              shuffle=False,
                              num_workers=0)
    #testset = data.SST(mode='test')
    #test_loader = DataLoader(dataset=testset,
    #                         batch_size=100,
    #                         collate_fn=data.SST.batcher,
    #                         shuffle=False,
    #                         num_workers=0)

    model = TreeLSTM(trainset.num_vocabs,
                     args.x_size,
                     args.h_size,
                     trainset.num_classes,
                     args.dropout)
    if cuda:
        model.cuda()
        # create zero initial states directly on the GPU when training there
        zero_initializer = lambda shape: th.zeros(shape).cuda()
    else:
        zero_initializer = th.zeros
    print(model)
    optimizer = optim.Adagrad(model.parameters(),
                              lr=args.lr,
                              weight_decay=args.weight_decay)
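    # per-step durations for the throughput report; the first steps of each
    # epoch are excluded as warm-up (see the `step >= 3` checks below)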
    dur = []
    for epoch in range(args.epochs):
        t_epoch = time.time()
        for step, batch in enumerate(train_loader):
            g = batch.graph
            if cuda:
                batch = _batch_to_cuda(batch)

            # skip the first steps of the epoch so timings exclude warm-up
            if step >= 3:
                t0 = time.time()
            # traverse the graph: the topological order is always computed
            # on the CPU copy of the graph (note the hard-coded cuda=False)
            giter = list(tensor_topo_traverse(g, False, args))
            logits = model(batch, zero_initializer, iterator=giter, train=True)
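            # cross-entropy loss, computed as log-softmax followed by NLL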
            logp = F.log_softmax(logits, 1)
            loss = F.nll_loss(logp, batch.label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if step >= 3:
                dur.append(time.time() - t0)

            if step > 0 and step % args.log_every == 0:
                pred = th.argmax(logits, 1)
                acc = th.sum(th.eq(batch.label, pred))
                mean_dur = np.mean(dur)
                print("Epoch {:05d} | Step {:05d} | Loss {:.4f} | "
                      "Acc {:.4f} | Time(s) {:.4f} | Trees/s {:.4f}".format(
                    epoch, step, loss.item(), acc.item()/len(batch.label),
                    mean_dur, args.batch_size / mean_dur))
        print("Epoch time(s):", time.time() - t_epoch)

        # test (kept in sync with the training-time model call above)
        #for step, batch in enumerate(test_loader):
        #    g = batch.graph
        #    giter = list(tensor_topo_traverse(g, False, args))
        #    logits = model(batch, zero_initializer, iterator=giter, train=False)
        #    pred = th.argmax(logits, 1)
        #    acc = th.sum(th.eq(batch.label, pred)).item() / len(batch.label)
        #    print(acc)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=int, default=-1)
    parser.add_argument('--batch-size', type=int, default=25)
    parser.add_argument('--x-size', type=int, default=256)
    parser.add_argument('--h-size', type=int, default=256)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--log-every', type=int, default=5)
    parser.add_argument('--lr', type=float, default=0.05)
    parser.add_argument('--n-ary', type=int, default=2)
    parser.add_argument('--weight-decay', type=float, default=1e-4)
    parser.add_argument('--dropout', type=float, default=0.5)
    args = parser.parse_args()
    main(args)