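"""Train a TreeLSTM model on the Stanford Sentiment Treebank (SST) using DGL.

Example invocation (all flags are defined by this script's argparse below):
    python train.py --gpu 0 --batch-size 25 --epochs 100
"""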
import argparse
import time
import numpy as np
import torch as th
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

import dgl
import dgl.data as data
import dgl.ndarray as nd

from tree_lstm import TreeLSTM

def tensor_topo_traverse(g, cuda, args):
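    """Yield frontiers of node ids in topological order.

    Each frontier contains the nodes whose remaining dependency count
    (computed with sparse matmuls on the graph's adjacency matrix) has
    dropped to zero, so for the SST trees leaves come first and the
    root comes last.
    """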
    n = g.number_of_nodes()
    if cuda:
        adjmat = g._graph.adjacency_matrix().get(th.device('cuda:{}'.format(args.gpu)))
        mask = th.ones((n, 1)).cuda()
    else:
        adjmat = g._graph.adjacency_matrix().get(th.device('cpu'))
        mask = th.ones((n, 1))
    degree = th.spmm(adjmat, mask)
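    # 'degree' holds each node's number of not-yet-processed predecessors;
    # repeatedly emit the zero-degree frontier and subtract its
    # contribution from the remaining degrees.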
    while th.sum(mask) != 0.:
        v = (degree == 0.).float()
        v = v * mask
        mask = mask - v
        frontier = th.squeeze(th.squeeze(v).nonzero(), 1)
        yield frontier
        degree -= th.spmm(adjmat, v)

def main(args):
    cuda = args.gpu >= 0
    if cuda:
        th.cuda.set_device(args.gpu)
    def _batcher(trees):
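        """Collate a list of tree graphs into a single batched graph,
        moving node features to the GPU when one is selected."""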
        bg = dgl.batch(trees)
        if cuda:
            for key in bg.node_attr_schemes().keys():
                bg.ndata[key] = bg.ndata[key].cuda()
        return bg
    trainset = data.SST()
    train_loader = DataLoader(dataset=trainset,
                              batch_size=args.batch_size,
                              collate_fn=_batcher,
                              shuffle=False,
                              num_workers=0)
    #testset = data.SST(mode='test')
    #test_loader = DataLoader(dataset=testset,
    #                         batch_size=100,
    #                         collate_fn=data.SST.batcher,
    #                         shuffle=False,
    #                         num_workers=0)

    model = TreeLSTM(trainset.num_vocabs,
                     args.x_size,
                     args.h_size,
                     trainset.num_classes,
                     args.dropout)
    if cuda:
        model.cuda()
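        # Zero-tensor factory on the same device as the model; passed to
        # the model below to initialize node states.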
        zero_initializer = lambda shape: th.zeros(shape).cuda()
    else:
        zero_initializer = th.zeros
    print(model)
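    # The Adagrad defaults (lr=0.05, weight decay 1e-4) follow the setting
    # reported in the Tree-LSTM paper (Tai et al., 2015).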
    optimizer = optim.Adagrad(model.parameters(),
                              lr=args.lr,
                              weight_decay=args.weight_decay)
    dur = []
    for epoch in range(args.epochs):
        t_epoch = time.time()
        for step, graph in enumerate(train_loader):
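            # Skip the first few steps when measuring step time so that
            # warm-up overhead does not skew the throughput numbers.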
            if step >= 3:
                t0 = time.time()
            label = graph.ndata.pop('y')
            # Compute the bottom-up traversal order on the CPU; the model
            # consumes it as an iterator of frontiers.
            giter = list(tensor_topo_traverse(graph, False, args))
            logits = model(graph, zero_initializer, iterator=giter, train=True)
            logp = F.log_softmax(logits, 1)
            loss = F.nll_loss(logp, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if step >= 3:
                dur.append(time.time() - t0)

            if step > 0 and step % args.log_every == 0:
                pred = th.argmax(logits, 1)
                acc = th.sum(th.eq(label, pred))
                mean_dur = np.mean(dur)
                print("Epoch {:05d} | Step {:05d} | Loss {:.4f} | "
                      "Acc {:.4f} | Time(s) {:.4f} | Trees/s {:.4f}".format(
                    epoch, step, loss.item(), acc.item() / len(label),
                    mean_dur, args.batch_size / mean_dur))
        print("Epoch time(s):", time.time() - t_epoch)

        # test
        #for step, batch in enumerate(test_loader):
        #    g = batch.graph
        #    n = g.number_of_nodes()
        #    x = th.zeros((n, args.x_size))
        #    h = th.zeros((n, args.h_size))
        #    c = th.zeros((n, args.h_size))
        #    logits = model(batch, x, h, c, train=True)
        #    pred = th.argmax(logits, 1)
        #    acc = th.sum(th.eq(batch.label, pred)) / len(batch.label)
        #    print(acc.item())

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=int, default=-1)
    parser.add_argument('--batch-size', type=int, default=25)
    parser.add_argument('--x-size', type=int, default=256)
    parser.add_argument('--h-size', type=int, default=256)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--log-every', type=int, default=5)
    parser.add_argument('--lr', type=float, default=0.05)
    parser.add_argument('--n-ary', type=int, default=2)
    parser.add_argument('--weight-decay', type=float, default=1e-4)
    parser.add_argument('--dropout', type=float, default=0.5)
    args = parser.parse_args()
    main(args)