""" Training and testing for node selection tasks in bAbI """ import argparse from data_utils import get_babi_dataloaders from ggnn_ns import NodeSelectionGGNN from torch.optim import Adam import torch import numpy as np import time def main(args): out_feats = {4: 4, 15: 5, 16: 6} n_etypes = {4: 4, 15: 2, 16: 2} train_dataloader, dev_dataloader, test_dataloaders = \ get_babi_dataloaders(batch_size=args.batch_size, train_size=args.train_num, task_id=args.task_id, q_type=args.question_id) model = NodeSelectionGGNN(annotation_size=1, out_feats=out_feats[args.task_id], n_steps=5, n_etypes=n_etypes[args.task_id]) opt = Adam(model.parameters(), lr=args.lr) print(f'Task {args.task_id}, question_id {args.question_id}') print(f'Training set size: {len(train_dataloader.dataset)}') print(f'Dev set size: {len(dev_dataloader.dataset)}') # training and dev stage for epoch in range(args.epochs): model.train() for i, batch in enumerate(train_dataloader): g, labels = batch loss, _ = model(g, labels) opt.zero_grad() loss.backward() opt.step() print(f'Epoch {epoch}, batch {i} loss: {loss.data}') dev_preds = [] dev_labels = [] model.eval() for g, labels in dev_dataloader: with torch.no_grad(): preds = model(g) preds = torch.tensor(preds, dtype=torch.long).data.numpy().tolist() labels = labels.data.numpy().tolist() dev_preds += preds dev_labels += labels acc = np.equal(dev_labels, dev_preds).astype(np.float).tolist() acc = sum(acc) / len(acc) print(f"Epoch {epoch}, Dev acc {acc}") # test stage for i, dataloader in enumerate(test_dataloaders): print(f'Test set {i} size: {len(dataloader.dataset)}') test_acc_list = [] for dataloader in test_dataloaders: test_preds = [] test_labels = [] model.eval() for g, labels in dataloader: with torch.no_grad(): preds = model(g) preds = torch.tensor(preds, dtype=torch.long).data.numpy().tolist() labels = labels.data.numpy().tolist() test_preds += preds test_labels += labels acc = np.equal(test_labels, test_preds).astype(np.float).tolist() acc = sum(acc) / len(acc) test_acc_list.append(acc) test_acc_mean = np.mean(test_acc_list) test_acc_std = np.std(test_acc_list) print(f'Mean of accuracy in 10 test datasets: {test_acc_mean}, std: {test_acc_std}') if __name__ == '__main__': parser = argparse.ArgumentParser(description='Gated Graph Neural Networks for node selection tasks in bAbI') parser.add_argument('--task_id', type=int, default=16, help='task id from 1 to 20') parser.add_argument('--question_id', type=int, default=1, help='question id for each task') parser.add_argument('--train_num', type=int, default=50, help='Number of training examples') parser.add_argument('--batch_size', type=int, default=10, help='batch size') parser.add_argument('--lr', type=float, default=1e-3, help='learning rate') parser.add_argument('--epochs', type=int, default=100, help='number of training epochs') args = parser.parse_args() main(args)