"launcher/vscode:/vscode.git/clone" did not exist on "2d0a7173d4891e7cd5f9b77f8e0987b82a339e51"
train_path_finding.py 3.61 KB
Newer Older
1
2
3
4
5
6
"""
Training and testing for sequence output tasks in bAbI.
Here we take task 19 'Path Finding' as an example
"""

import argparse

import numpy as np
import torch
from data_utils import get_babi_dataloaders
from ggsnn import GGSNN
from torch.optim import Adam


def _evaluate(model, dataloader):
    """Return exact-match sequence accuracy of ``model`` on ``dataloader``.

    A prediction counts as correct only if the entire predicted output
    sequence equals the ground-truth sequence.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for g, ground_truths, seq_lengths in dataloader:
            preds = model(g, seq_lengths).data.numpy().tolist()
            truths = ground_truths.data.numpy().tolist()
            correct += sum(1 for p, t in zip(preds, truths) if p == t)
            total += len(truths)
    # Guard against an empty dataloader instead of raising ZeroDivisionError.
    return correct / total if total else 0.0


def main(args):
    """Train and evaluate a GGS-NN on a bAbI sequence-output task.

    Args:
        args: parsed CLI namespace providing ``task_id``, ``train_num``,
            ``batch_size``, ``lr`` and ``epochs``.
    """
    # Per-task model hyper-parameters; only task 19 ('Path Finding') is
    # configured here, matching the module docstring.
    out_feats = {19: 6}
    n_etypes = {19: 4}

    train_dataloader, dev_dataloader, test_dataloaders = get_babi_dataloaders(
        batch_size=args.batch_size,
        train_size=args.train_num,
        task_id=args.task_id,
        q_type=-1,  # -1 selects all question types for this task
    )

    model = GGSNN(
        annotation_size=2,
        out_feats=out_feats[args.task_id],
        n_steps=5,
        n_etypes=n_etypes[args.task_id],
        max_seq_length=2,
        num_cls=5,
    )

    opt = Adam(model.parameters(), lr=args.lr)

    print(f"Task {args.task_id}")
    print(f"Training set size: {len(train_dataloader.dataset)}")
    print(f"Dev set size: {len(dev_dataloader.dataset)}")

    # Training stage; losses and dev accuracy are reported every 20 epochs.
    for epoch in range(args.epochs):
        model.train()
        for batch_idx, batch in enumerate(train_dataloader):
            g, ground_truths, seq_lengths = batch
            loss, _ = model(g, seq_lengths, ground_truths)
            opt.zero_grad()
            loss.backward()
            opt.step()
            if epoch % 20 == 0:
                print(f"Epoch {epoch}, batch {batch_idx} loss: {loss.data}")

        if epoch % 20 != 0:
            continue
        acc = _evaluate(model, dev_dataloader)
        print(f"Epoch {epoch}, Dev acc {acc}")

    # Test stage: report each split's size, then its exact-match accuracy.
    for i, dataloader in enumerate(test_dataloaders):
        print(f"Test set {i} size: {len(dataloader.dataset)}")

    test_acc_list = [_evaluate(model, dataloader) for dataloader in test_dataloaders]

    test_acc_mean = np.mean(test_acc_list)
    test_acc_std = np.std(test_acc_list)

    # Report the actual number of test splits rather than a hard-coded 10.
    print(
        f"Mean of accuracy in {len(test_acc_list)} test datasets: "
        f"{test_acc_mean}, std: {test_acc_std}"
    )


if __name__ == "__main__":
    # Build the CLI, parse the arguments, and hand off to main().
    arg_parser = argparse.ArgumentParser(
        description="Gated Graph Sequence Neural Networks for sequential output tasks in "
        "bAbI"
    )
    arg_parser.add_argument(
        "--task_id", type=int, default=19, help="task id from 1 to 20"
    )
    arg_parser.add_argument(
        "--train_num", type=int, default=250, help="Number of training examples"
    )
    arg_parser.add_argument("--batch_size", type=int, default=10, help="batch size")
    arg_parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
    arg_parser.add_argument(
        "--epochs", type=int, default=200, help="number of training epochs"
    )

    main(arg_parser.parse_args())