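"""Train an N-ary Tree-LSTM sentiment classifier on SST-style trees.

A minimal training script: trees are streamed one at a time from
nx_SST.nx_BCT_Reader, the mean negative log-likelihood over each batch
is minimized with Adagrad, and the running loss is printed periodically.
"""
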
import argparse

import torch as th
import torch.optim as optim

import nx_SST     # project-local SST tree reader
import tree_lstm  # project-local Tree-LSTM model

parser = argparse.ArgumentParser(
    description='Train an N-ary Tree-LSTM on SST-style trees.')
parser.add_argument('--batch-size', type=int, default=25)
parser.add_argument('--gpu', type=int, default=0)       # negative => CPU
parser.add_argument('--h-size', type=int, default=512)  # hidden state size
parser.add_argument('--log-every', type=int, default=1)
parser.add_argument('--lr', type=float, default=0.05)
parser.add_argument('--n-ary', type=int, default=2)     # children per node
parser.add_argument('--n-iterations', type=int, default=1000)
parser.add_argument('--weight-decay', type=float, default=1e-4)
parser.add_argument('--x-size', type=int, default=256)  # input embedding size
args = parser.parse_args()

# A negative --gpu value selects CPU; otherwise pin to the given device.
cuda = args.gpu >= 0
if cuda:
    th.cuda.set_device(args.gpu)

# Build the tree reader and an endless generator of training graphs.
reader = nx_SST.nx_BCT_Reader(cuda)
loader = reader.generator()

# N-ary Tree-LSTM; the vocabulary size is padded by one slot (presumably
# for an out-of-vocabulary index), with 5 output sentiment classes.
network = tree_lstm.NAryTreeLSTM(len(reader.vocab) + 1,
                                 args.x_size, args.h_size, args.n_ary, 5)
if cuda:
    network.cuda()
adagrad = optim.Adagrad(network.parameters(), lr=args.lr,
                        weight_decay=args.weight_decay)

for i in range(args.n_iterations):
    # Average the negative log-likelihood over one batch of trees,
    # pulling one graph at a time from the reader's generator.
    nll = 0
    for _ in range(args.batch_size):
        g = next(loader)
        nll += network(g, train=True)
    nll /= args.batch_size

    adagrad.zero_grad()
    nll.backward()
    adagrad.step()

    if (i + 1) % args.log_every == 0:
        print('[iteration %d] cross-entropy loss: %f' % (i + 1, nll.item()))
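
# Example invocation (a negative --gpu runs on CPU):
#   python test.py --gpu 0 --batch-size 25 --lr 0.05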