# train_ns.py
"""
Training and testing for node selection tasks in bAbI
"""

import argparse
from data_utils import get_babi_dataloaders
from ggnn_ns import NodeSelectionGGNN
from torch.optim import Adam
import torch
import numpy as np
import time


def _evaluate(model, dataloader):
    """Return the accuracy of ``model`` over every batch yielded by ``dataloader``.

    Each batch is unpacked as ``(g, labels)``; ``model(g)`` (called without
    labels) must return predictions that compare element-wise with ``labels``.
    Returns ``nan`` when the dataloader yields no examples.
    """
    all_preds = []
    all_labels = []
    model.eval()
    with torch.no_grad():
        for g, labels in dataloader:
            preds = model(g)
            # as_tensor accepts both a list and an existing tensor without the
            # copy-construct warning of torch.tensor; .cpu() keeps the
            # conversion working if the model returns a GPU tensor.
            all_preds += torch.as_tensor(preds, dtype=torch.long).cpu().tolist()
            all_labels += labels.cpu().tolist()
    if not all_labels:
        # Avoid ZeroDivisionError on an empty split.
        return float('nan')
    # np.float was removed in NumPy 1.24; the builtin float is the replacement.
    correct = np.equal(all_labels, all_preds).astype(float)
    return float(correct.mean())


def main(args):
    """Train a NodeSelectionGGNN on one bAbI task and report test accuracy.

    The model is trained for ``args.epochs`` epochs, with dev-set accuracy
    printed after each epoch. After training, accuracy on every test
    dataloader is computed, and the mean and standard deviation are printed.

    Args:
        args: parsed command-line namespace providing ``task_id``,
            ``question_id``, ``train_num``, ``batch_size``, ``lr`` and
            ``epochs``.

    Raises:
        ValueError: if ``args.task_id`` has no model configuration below.
    """
    # Per-task output size and number of edge types for the GGNN.
    out_feats = {4: 4, 15: 5, 16: 6}
    n_etypes = {4: 4, 15: 2, 16: 2}
    # Fail fast (before loading data) on tasks without a configuration.
    if args.task_id not in out_feats:
        raise ValueError(
            f'Unsupported task_id {args.task_id}; '
            f'expected one of {sorted(out_feats)}')

    train_dataloader, dev_dataloader, test_dataloaders = \
        get_babi_dataloaders(batch_size=args.batch_size,
                             train_size=args.train_num,
                             task_id=args.task_id,
                             q_type=args.question_id)

    model = NodeSelectionGGNN(annotation_size=1,
                              out_feats=out_feats[args.task_id],
                              n_steps=5,
                              n_etypes=n_etypes[args.task_id])
    opt = Adam(model.parameters(), lr=args.lr)

    print(f'Task {args.task_id}, question_id {args.question_id}')

    print(f'Training set size: {len(train_dataloader.dataset)}')
    print(f'Dev set size: {len(dev_dataloader.dataset)}')

    # training and dev stage
    for epoch in range(args.epochs):
        model.train()
        for i, (g, labels) in enumerate(train_dataloader):
            # With labels, the model returns (loss, predictions).
            loss, _ = model(g, labels)
            opt.zero_grad()
            loss.backward()
            opt.step()
            print(f'Epoch {epoch}, batch {i} loss: {loss.item()}')

        acc = _evaluate(model, dev_dataloader)
        print(f"Epoch {epoch}, Dev acc {acc}")

    # test stage
    for i, dataloader in enumerate(test_dataloaders):
        print(f'Test set {i} size: {len(dataloader.dataset)}')

    test_acc_list = [_evaluate(model, dataloader)
                     for dataloader in test_dataloaders]

    test_acc_mean = np.mean(test_acc_list)
    test_acc_std = np.std(test_acc_list)

    print(f'Mean of accuracy in {len(test_acc_list)} test datasets: '
          f'{test_acc_mean}, std: {test_acc_std}')


if __name__ == '__main__':
    # Command-line entry point: parse hyper-parameters and run main().
    parser = argparse.ArgumentParser(description='Gated Graph Neural Networks for node selection tasks in bAbI')
    # main() only has model configurations for tasks 4, 15 and 16, so reject
    # anything else at parse time with a clear usage error.
    parser.add_argument('--task_id', type=int, default=16, choices=[4, 15, 16],
                        help='bAbI node selection task id (4, 15 or 16)')
    parser.add_argument('--question_id', type=int, default=1,
                        help='question id for each task')
    parser.add_argument('--train_num', type=int, default=50,
                        help='Number of training examples')
    parser.add_argument('--batch_size', type=int, default=10,
                        help='batch size')
    parser.add_argument('--lr', type=float, default=1e-3,
                        help='learning rate')
    parser.add_argument('--epochs', type=int, default=100,
                        help='number of training epochs')

    args = parser.parse_args()

    main(args)