Commit c37076df authored by Xiagkun Hu, committed by Mufei Li

[Model] Gated Graph Neural Network for bAbI tasks (#880)

* rrn model and sudoku

* add README

* refine the code, add doc strings

* add sudoku solver

* add example for sudoku_solver

* ggnn example

* Rewrite README file

* fix typos
parent 8d848655
# Gated Graph Neural Network (GGNN)
- Paper link: https://arxiv.org/pdf/1511.05493.pdf
## Dependencies
- PyTorch 1.0+
- DGL 0.3.1+
## GGNN implemented in dgl
In DGL, GGNN is implemented as the module `GatedGraphConv`, which can be imported as follows:
```python
from dgl.nn.pytorch import GatedGraphConv
```
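To get a feel for the layer's interface, here is a minimal sketch on a toy graph (the sizes here are made up for illustration); it uses the same call signature as the models in this example:
```python
import torch
import dgl
from dgl.nn.pytorch import GatedGraphConv

# A toy graph with 3 nodes and 2 directed, typed edges.
g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges([0, 1], [1, 2])
etypes = torch.tensor([0, 1])   # one type id per edge
feat = torch.zeros(3, 10)       # initial node states
conv = GatedGraphConv(in_feats=10, out_feats=10, n_steps=5, n_etypes=2)
out = conv(g, feat, etypes)     # updated node states, shape (3, 10)
```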
## Solving bAbI tasks
In this example, we use GGNN to solve some of the [bAbI](https://github.com/facebook/bAbI-tasks)
tasks addressed in the paper.
#### Overview of bAbI tasks
bAbI is a set of question answering tasks that require a system to perform multi-step reasoning.
The datasets for the bAbI tasks are generated from templates, in either natural language or
symbolic form. In this example, we follow the paper and generate the datasets in symbolic form.
bAbI consists of 20 tasks in total; following the paper, we work on tasks 4, 15, 16, 18 and 19.
#### Task 4: Two argument relations: subject vs. object
An example of task 4 is as follows:
```
1 C e A
2 A e B
3 eval A w C
```
A, B and C are nodes; e and w are edges. There are four kinds of edges in total, `n, s, w, e`,
which can be read as north, south, west and east.
The first two lines are conditions, and the third line is the question together with its answer.
So the example can be read as:
```
1 Go east from C, we can reach A
2 Go east from A, we can reach B
3 Question: where can we reach if we go west from A? Answer: C
```
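As a concrete illustration, here is a minimal sketch (mirroring the collate logic in `data_utils.py`) that turns the two conditions above into a DGL graph, with nodes indexed in order of appearance (C=0, A=1, B=2) and edge type `e` mapped to id 3:
```python
import torch
import dgl

g = dgl.DGLGraph()
g.add_nodes(3)                           # C, A, B
g.add_edges([0, 1], [1, 2])              # C -e-> A, A -e-> B
g.edata['type'] = torch.tensor([3, 3])   # type id of edge 'e'
# Mark the question's source node (A) as the initial annotation.
annotation = torch.zeros(3, 1)
annotation[1] = 1
g.ndata['annotation'] = annotation
```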
If we represent the conditions as a graph, we can view this task as a `Node Selection` task.
We treat the different edge types that appear in questions as different question types, and train
a separate model for each question type. The module for solving node selection tasks is
implemented in `ggnn_ns.py`.
For the four question types `n, s, w, e`, we assign question ids ranging from 0 to 3.
For each question id, run one of the following commands for training and testing:
```bash
python train_ns.py --task_id=4 --question_id=0 --train_num=50 --epochs=10
python train_ns.py --task_id=4 --question_id=1 --train_num=50 --epochs=10
python train_ns.py --task_id=4 --question_id=2 --train_num=50 --epochs=10
python train_ns.py --task_id=4 --question_id=3 --train_num=50 --epochs=10
```
The script name `train_ns` stands for training node selection; `train_num` is the number of
training examples used.
#### Task 15: Basic deduction
Task 15 is similar to task 4: it is also a `Node Selection` task. An example is shown below:
```
1 I has_fear C
2 H is C
3 G is I
4 A is B
5 E has_fear C
6 C has_fear I
7 B has_fear C
8 F is E
9 eval H has_fear I
```
There are two types of edges in this task: `is` and `has_fear`. There is only one question type,
`has_fear`, to which we assign question id `1`.
Run the following command for training and testing:
```bash
python train_ns.py --task_id=15 --question_id=1 --train_num=50 --epochs=15 --lr=1e-2
```
#### Task 16: Basic induction
Task 16 is similar to task 15. An example of task 16 is shown below:
```
1 J has_color F
2 K has_color I
3 A has_color I
4 G is D
5 J is C
6 H has_color I
7 H is D
8 A is D
9 K is D
10 eval G has_color I
```
There are two types of edges in this task: `is` and `has_color`. There is only one question type,
`has_color`, to which we assign question id `1`.
Run the following command for training and testing:
```bash
python train_ns.py --task_id=16 --question_id=1 --train_num=50 --epochs=20 --lr=1e-2
```
#### Task 18: Reasoning about size
Task 18 is a `Graph Classification` task. An example is shown below:
```
1 G > B
2 G > D
3 E > F
4 E > A
5 B > A
6 E > B
7 eval G < A false
```
Lines 1 to 6 give conditions comparing the sizes of entities; line 7 is the question, asking
whether `G < A` is `true` or `false`. The input is thus a graph and the output a binary
classification result, so we view this as a `Graph Classification` task.
Following the paper, we use GGNN to encode the graph, followed by a `GlobalAttentionPooling`
layer that pools the node states into a single graph-level vector used to classify the graph.
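A minimal sketch of this pooling step (with a made-up feature size; `ggnn_gc.py` uses `annotation_size + out_feats`):
```python
import torch
from torch import nn
import dgl
from dgl.nn.pytorch import GlobalAttentionPooling

g = dgl.DGLGraph()
g.add_nodes(4)
feats = torch.randn(4, 5)        # per-node features of dimension 5
gate_nn = nn.Linear(5, 1)        # scores each node to form attention weights
pooling = GlobalAttentionPooling(gate_nn)
graph_repr = pooling(g, feats)   # pooled graph-level representation
```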
The module for solving graph classification tasks is implemented in `ggnn_gc.py`.
There are two types of edges in this task, `>` and `<`, which also serve as the question types.
We assign them question ids `0` and `1`.
Run the following commands for training and testing:
```bash
python train_gc.py --task_id=18 --question_id=0 --train_num=50 --batch_size=10 --lr=1e-3 --epochs=20
python train_gc.py --task_id=18 --question_id=1 --train_num=50 --batch_size=10 --lr=1e-3 --epochs=20
```
#### Task 19: Path finding
An example of task 19 is as follows:
```
1 D n A
2 D s E
3 G w D
4 E s B
5 eval path G A w,n
```
Similar to task 4, there are four types of edges, `n, s, w, e`, which can be read as north,
south, west and east. The conditions take the same form as in task 4; the question in line 5
means `Question: find a path from G to A. Answer: first go west, then go north`. The output is a
sequence of edges, so there is no question type in this task.
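Concretely, `data_utils.py` maps the answer `w,n` to edge ids via `edge_dict` and pads it with an `<end>` token up to the maximum sequence length:
```python
edge_dict = {'n': 0, 's': 1, 'w': 2, 'e': 3, '<end>': 4}
max_seq_length = 2

seq_out = [edge_dict[e] for e in 'w,n'.split(',')]   # [2, 0]
# Pad with <end> so every target has length max_seq_length.
truth = seq_out + [edge_dict['<end>']] * (max_seq_length - len(seq_out))
```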
The paper uses *Gated Graph Sequence Neural Networks (GGS-NNs)* to solve this kind of problem.
In this example, we implement GGS-NNs in `ggsnn.py`. Run the following command for training
and testing:
```bash
python train_path_finding.py --train_num=250 --epochs=200
```
#### Results
Following the paper, we use 10 different test sets for evaluation. The result is the mean and
standard deviation of the evaluation performance across the 10 datasets. Numbers in the parentheses
are the number of training data used.
| Task ID | Reported <br> Accuracy | DGL <br> Accuracy |
|:---------:|-----------------------------|------------------------------|
| 4 | 100.0 ± 0.00 (50) | 100.0 ± 0.00 (50)|
| 15 | 100.0 ± 0.00 (50) | 100.0 ± 0.00 (50)|
| 16 | 100.0 ± 0.00 (50) | 100.0 ± 0.00 (50)|
| 18 | 100.0 ± 0.00 (50) | 100.0 ± 0.00 (50)|
| 19 | 99.0 ± 1.1 (250) | 97.8 ± 0.02 (50) |
"""
Data utils for processing bAbI datasets
"""
import os
from torch.utils.data import DataLoader
import dgl
import torch
import string
from dgl.data.utils import download, get_download_dir, _get_dgl_url, extract_archive
def get_babi_dataloaders(batch_size, train_size=50, task_id=4, q_type=0):
_download_babi_data()
node_dict = dict(zip(list(string.ascii_uppercase), range(len(string.ascii_uppercase))))
if task_id == 4:
edge_dict = {'n': 0, 's': 1, 'w': 2, 'e': 3}
reverse_edge = {}
return _ns_dataloader(train_size, q_type, batch_size, node_dict, edge_dict, reverse_edge, '04')
elif task_id == 15:
edge_dict = {'is': 0, 'has_fear': 1}
reverse_edge = {}
return _ns_dataloader(train_size, q_type, batch_size, node_dict, edge_dict, reverse_edge, '15')
elif task_id == 16:
edge_dict = {'is': 0, 'has_color': 1}
reverse_edge = {0: 0}
return _ns_dataloader(train_size, q_type, batch_size, node_dict, edge_dict, reverse_edge, '16')
elif task_id == 18:
edge_dict = {'>': 0, '<': 1}
label_dict = {'false': 0, 'true': 1}
reverse_edge = {0: 1, 1: 0}
return _gc_dataloader(train_size, q_type, batch_size, node_dict, edge_dict, label_dict, reverse_edge, '18')
elif task_id == 19:
edge_dict = {'n': 0, 's': 1, 'w': 2, 'e': 3, '<end>': 4}
reverse_edge = {0: 1, 1: 0, 2: 3, 3: 2}
max_seq_length = 2
return _path_finding_dataloader(train_size, batch_size, node_dict, edge_dict, reverse_edge, '19', max_seq_length)
def _ns_dataloader(train_size, q_type, batch_size, node_dict, edge_dict, reverse_edge, path):
def _collate_fn(batch):
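        # Build one DGLGraph per example and batch them; the label is the
        # (within-graph) index of the answer node.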
graphs = []
labels = []
for d in batch:
edges = d['edges']
node_ids = []
for s, e, t in edges:
if s not in node_ids:
node_ids.append(s)
if t not in node_ids:
node_ids.append(t)
g = dgl.DGLGraph()
g.add_nodes(len(node_ids))
g.ndata['node_id'] = torch.tensor(node_ids, dtype=torch.long)
nid2idx = dict(zip(node_ids, list(range(len(node_ids)))))
# convert label to node index
label = d['eval'][2]
label_idx = nid2idx[label]
labels.append(label_idx)
edge_types = []
for s, e, t in edges:
g.add_edge(nid2idx[s], nid2idx[t])
edge_types.append(e)
if e in reverse_edge:
g.add_edge(nid2idx[t], nid2idx[s])
edge_types.append(reverse_edge[e])
g.edata['type'] = torch.tensor(edge_types, dtype=torch.long)
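            # One-hot annotation marking the question's source node; this is
            # the initial node feature fed to the GGNN.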
annotation = torch.zeros(len(node_ids), dtype=torch.long)
annotation[nid2idx[d['eval'][0]]] = 1
g.ndata['annotation'] = annotation.unsqueeze(-1)
graphs.append(g)
batch_graph = dgl.batch(graphs)
labels = torch.tensor(labels, dtype=torch.long)
return batch_graph, labels
def _get_dataloader(data, shuffle):
return DataLoader(dataset=data, batch_size=batch_size, shuffle=shuffle, collate_fn=_collate_fn)
train_set, dev_set, test_sets = _convert_ns_dataset(train_size, node_dict, edge_dict, path, q_type)
train_dataloader = _get_dataloader(train_set, True)
dev_dataloader = _get_dataloader(dev_set, False)
test_dataloaders = []
for d in test_sets:
dl = _get_dataloader(d, False)
test_dataloaders.append(dl)
return train_dataloader, dev_dataloader, test_dataloaders
def _convert_ns_dataset(train_size, node_dict, edge_dict, path, q_type):
total_num = 11000
def convert(file):
dataset = []
d = dict()
with open(file, 'r') as f:
for i, line in enumerate(f.readlines()):
line = line.strip().split()
if line[0] == '1' and len(d) > 0:
d = dict()
if line[1] == 'eval':
# (src, edge, label)
d['eval'] = (node_dict[line[2]], edge_dict[line[3]], node_dict[line[4]])
if d['eval'][1] == q_type:
dataset.append(d)
if len(dataset) >= total_num:
break
else:
if 'edges' not in d:
d['edges'] = []
d['edges'].append((node_dict[line[1]], edge_dict[line[2]], node_dict[line[3]]))
return dataset
download_dir = get_download_dir()
filename = os.path.join(download_dir, 'babi_data', path, 'data.txt')
data = convert(filename)
assert len(data) == total_num
train_set = data[:train_size]
dev_set = data[950:1000]
test_sets = []
for i in range(10):
test = data[1000 * (i + 1): 1000 * (i + 2)]
test_sets.append(test)
return train_set, dev_set, test_sets
def _gc_dataloader(train_size, q_type, batch_size, node_dict, edge_dict, label_dict, reverse_edge, path):
def _collate_fn(batch):
graphs = []
labels = []
for d in batch:
edges = d['edges']
node_ids = []
for s, e, t in edges:
if s not in node_ids:
node_ids.append(s)
if t not in node_ids:
node_ids.append(t)
g = dgl.DGLGraph()
g.add_nodes(len(node_ids))
g.ndata['node_id'] = torch.tensor(node_ids, dtype=torch.long)
nid2idx = dict(zip(node_ids, list(range(len(node_ids)))))
labels.append(d['eval'][-1])
edge_types = []
for s, e, t in edges:
g.add_edge(nid2idx[s], nid2idx[t])
edge_types.append(e)
if e in reverse_edge:
g.add_edge(nid2idx[t], nid2idx[s])
edge_types.append(reverse_edge[e])
g.edata['type'] = torch.tensor(edge_types, dtype=torch.long)
annotation = torch.zeros([len(node_ids), 2], dtype=torch.long)
annotation[nid2idx[d['eval'][0]]][0] = 1
annotation[nid2idx[d['eval'][2]]][1] = 1
g.ndata['annotation'] = annotation
graphs.append(g)
batch_graph = dgl.batch(graphs)
labels = torch.tensor(labels, dtype=torch.long)
return batch_graph, labels
def _get_dataloader(data, shuffle):
return DataLoader(dataset=data, batch_size=batch_size, shuffle=shuffle, collate_fn=_collate_fn)
train_set, dev_set, test_sets = _convert_gc_dataset(train_size, node_dict, edge_dict, label_dict, path, q_type)
train_dataloader = _get_dataloader(train_set, True)
dev_dataloader = _get_dataloader(dev_set, False)
test_dataloaders = []
for d in test_sets:
dl = _get_dataloader(d, False)
test_dataloaders.append(dl)
return train_dataloader, dev_dataloader, test_dataloaders
def _convert_gc_dataset(train_size, node_dict, edge_dict, label_dict, path, q_type):
total_num = 11000
def convert(file):
dataset = []
d = dict()
with open(file, 'r') as f:
for i, line in enumerate(f.readlines()):
line = line.strip().split()
if line[0] == '1' and len(d) > 0:
d = dict()
if line[1] == 'eval':
                    # (src, edge, dst, label)
if 'eval' not in d:
d['eval'] = (node_dict[line[2]], edge_dict[line[3]], node_dict[line[4]], label_dict[line[5]])
if d['eval'][1] == q_type:
dataset.append(d)
if len(dataset) >= total_num:
break
else:
if 'edges' not in d:
d['edges'] = []
d['edges'].append((node_dict[line[1]], edge_dict[line[2]], node_dict[line[3]]))
return dataset
download_dir = get_download_dir()
filename = os.path.join(download_dir, 'babi_data', path, 'data.txt')
data = convert(filename)
assert len(data) == total_num
train_set = data[:train_size]
dev_set = data[950:1000]
test_sets = []
for i in range(10):
test = data[1000 * (i + 1): 1000 * (i + 2)]
test_sets.append(test)
return train_set, dev_set, test_sets
def _path_finding_dataloader(train_size, batch_size, node_dict, edge_dict, reverse_edge, path, max_seq_length):
def _collate_fn(batch):
graphs = []
ground_truths = []
seq_lengths = []
for d in batch:
edges = d['edges']
node_ids = []
for s, e, t in edges:
if s not in node_ids:
node_ids.append(s)
if t not in node_ids:
node_ids.append(t)
g = dgl.DGLGraph()
g.add_nodes(len(node_ids))
g.ndata['node_id'] = torch.tensor(node_ids, dtype=torch.long)
nid2idx = dict(zip(node_ids, list(range(len(node_ids)))))
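            # Pad the target edge sequence with <end> tokens up to max_seq_length.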
truth = d['seq_out'] + [edge_dict['<end>']] * (max_seq_length - len(d['seq_out']))
seq_len = len(d['seq_out'])
ground_truths.append(truth)
seq_lengths.append(seq_len)
edge_types = []
for s, e, t in edges:
g.add_edge(nid2idx[s], nid2idx[t])
edge_types.append(e)
if e in reverse_edge:
g.add_edge(nid2idx[t], nid2idx[s])
edge_types.append(reverse_edge[e])
g.edata['type'] = torch.tensor(edge_types, dtype=torch.long)
annotation = torch.zeros([len(node_ids), 2], dtype=torch.long)
annotation[nid2idx[d['eval'][0]]][0] = 1
annotation[nid2idx[d['eval'][1]]][1] = 1
g.ndata['annotation'] = annotation
graphs.append(g)
batch_graph = dgl.batch(graphs)
ground_truths = torch.tensor(ground_truths, dtype=torch.long)
seq_lengths = torch.tensor(seq_lengths, dtype=torch.long)
return batch_graph, ground_truths, seq_lengths
def _get_dataloader(data, shuffle):
return DataLoader(dataset=data, batch_size=batch_size, shuffle=shuffle, collate_fn=_collate_fn)
train_set, dev_set, test_sets = _convert_path_finding(train_size, node_dict, edge_dict, path)
train_dataloader = _get_dataloader(train_set, True)
dev_dataloader = _get_dataloader(dev_set, False)
test_dataloaders = []
for d in test_sets:
dl = _get_dataloader(d, False)
test_dataloaders.append(dl)
return train_dataloader, dev_dataloader, test_dataloaders
def _convert_path_finding(train_size, node_dict, edge_dict, path):
total_num = 11000
def convert(file):
dataset = []
d = dict()
with open(file, 'r') as f:
for line in f.readlines():
line = line.strip().split()
if line[0] == '1' and len(d) > 0:
d = dict()
if line[1] == 'eval':
                    # (src, dst)
d['eval'] = (node_dict[line[3]], node_dict[line[4]])
d['seq_out'] = []
seq_out = line[5].split(',')
for e in seq_out:
d['seq_out'].append(edge_dict[e])
dataset.append(d)
if len(dataset) >= total_num:
break
else:
if 'edges' not in d:
d['edges'] = []
d['edges'].append((node_dict[line[1]], edge_dict[line[2]], node_dict[line[3]]))
return dataset
download_dir = get_download_dir()
filename = os.path.join(download_dir, 'babi_data', path, 'data.txt')
data = convert(filename)
assert len(data) == total_num
train_set = data[:train_size]
dev_set = data[950:1000]
test_sets = []
for i in range(10):
test = data[1000 * (i + 1): 1000 * (i + 2)]
test_sets.append(test)
return train_set, dev_set, test_sets
def _download_babi_data():
download_dir = get_download_dir()
zip_file_path = os.path.join(download_dir, 'babi_data.zip')
data_url = _get_dgl_url('models/ggnn_babi_data.zip')
download(data_url, path=zip_file_path)
extract_dir = os.path.join(download_dir, 'babi_data')
if not os.path.exists(extract_dir):
extract_archive(zip_file_path, extract_dir)
"""
Gated Graph Neural Network module for graph classification tasks
"""
from dgl.nn.pytorch import GatedGraphConv, GlobalAttentionPooling
import torch
from torch import nn
class GraphClsGGNN(nn.Module):
def __init__(self,
annotation_size,
out_feats,
n_steps,
n_etypes,
num_cls):
super(GraphClsGGNN, self).__init__()
self.annotation_size = annotation_size
self.out_feats = out_feats
self.ggnn = GatedGraphConv(in_feats=out_feats,
out_feats=out_feats,
n_steps=n_steps,
n_etypes=n_etypes)
pooling_gate_nn = nn.Linear(annotation_size + out_feats, 1)
self.pooling = GlobalAttentionPooling(pooling_gate_nn)
self.output_layer = nn.Linear(annotation_size + out_feats, num_cls)
self.loss_fn = nn.CrossEntropyLoss()
def forward(self, graph, labels=None):
etypes = graph.edata.pop('type')
annotation = graph.ndata.pop('annotation').float()
assert annotation.size()[-1] == self.annotation_size
node_num = graph.number_of_nodes()
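        # Zero-pad the annotations so the initial node state matches out_feats,
        # as in the paper.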
zero_pad = torch.zeros([node_num, self.out_feats - self.annotation_size],
dtype=torch.float,
device=annotation.device)
h1 = torch.cat([annotation, zero_pad], -1)
out = self.ggnn(graph, h1, etypes)
out = torch.cat([out, annotation], -1)
out = self.pooling(graph, out)
logits = self.output_layer(out)
preds = torch.argmax(logits, -1)
if labels is not None:
loss = self.loss_fn(logits, labels)
return loss, preds
return preds
"""
Gated Graph Neural Network module for node selection tasks
"""
from dgl.nn.pytorch import GatedGraphConv
import torch
from torch import nn
import dgl
class NodeSelectionGGNN(nn.Module):
def __init__(self,
annotation_size,
out_feats,
n_steps,
n_etypes):
super(NodeSelectionGGNN, self).__init__()
self.annotation_size = annotation_size
self.out_feats = out_feats
self.ggnn = GatedGraphConv(in_feats=out_feats,
out_feats=out_feats,
n_steps=n_steps,
n_etypes=n_etypes)
self.output_layer = nn.Linear(annotation_size + out_feats, 1)
self.loss_fn = nn.CrossEntropyLoss()
def forward(self, graph, labels=None):
etypes = graph.edata.pop('type')
annotation = graph.ndata.pop('annotation').float()
assert annotation.size()[-1] == self.annotation_size
node_num = graph.number_of_nodes()
zero_pad = torch.zeros([node_num, self.out_feats - self.annotation_size],
dtype=torch.float,
device=annotation.device)
h1 = torch.cat([annotation, zero_pad], -1)
out = self.ggnn(graph, h1, etypes)
all_logits = self.output_layer(torch.cat([out, annotation], -1)).squeeze(-1)
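        # Per-node scalar scores; unbatch so that argmax and cross-entropy are
        # computed within each individual graph.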
graph.ndata['logits'] = all_logits
batch_g = dgl.unbatch(graph)
preds = []
if labels is not None:
loss = 0.0
for i, g in enumerate(batch_g):
logits = g.ndata['logits']
preds.append(torch.argmax(logits))
if labels is not None:
logits = logits.unsqueeze(0)
y = labels[i].unsqueeze(0)
loss += self.loss_fn(logits, y)
if labels is not None:
loss /= float(len(batch_g))
return loss, preds
return preds
"""
Gated Graph Sequence Neural Network for sequence outputs
"""
from dgl.nn.pytorch import GatedGraphConv, GlobalAttentionPooling
import torch
from torch import nn
import torch.nn.functional as F
class GGSNN(nn.Module):
def __init__(self,
annotation_size,
out_feats,
n_steps,
n_etypes,
max_seq_length,
num_cls):
super(GGSNN, self).__init__()
self.annotation_size = annotation_size
self.out_feats = out_feats
self.max_seq_length = max_seq_length
self.ggnn = GatedGraphConv(in_feats=out_feats,
out_feats=out_feats,
n_steps=n_steps,
n_etypes=n_etypes)
self.annotation_out_layer = nn.Linear(annotation_size + out_feats, annotation_size)
pooling_gate_nn = nn.Linear(annotation_size + out_feats, 1)
self.pooling = GlobalAttentionPooling(pooling_gate_nn)
self.output_layer = nn.Linear(annotation_size + out_feats, num_cls)
self.loss_fn = nn.CrossEntropyLoss(reduction='none')
def forward(self, graph, seq_lengths, ground_truth=None):
etypes = graph.edata.pop('type')
annotation = graph.ndata.pop('annotation').float()
assert annotation.size()[-1] == self.annotation_size
node_num = graph.number_of_nodes()
all_logits = []
for _ in range(self.max_seq_length):
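            # One decoding step: run the GGNN, pool a graph-level logit for the
            # next output token, then predict new node annotations for the next step.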
zero_pad = torch.zeros([node_num, self.out_feats - self.annotation_size],
dtype=torch.float,
device=annotation.device)
h1 = torch.cat([annotation.detach(), zero_pad], -1)
out = self.ggnn(graph, h1, etypes)
out = torch.cat([out, annotation], -1)
logits = self.pooling(graph, out)
logits = self.output_layer(logits)
all_logits.append(logits)
annotation = self.annotation_out_layer(out)
annotation = F.softmax(annotation, -1)
all_logits = torch.stack(all_logits, 1)
preds = torch.argmax(all_logits, -1)
if ground_truth is not None:
loss = sequence_loss(all_logits, ground_truth, seq_lengths)
return loss, preds
return preds
def sequence_loss(logits, ground_truth, seq_length=None):
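    # Per-step cross-entropy, masked so that padded positions beyond each
    # sequence's true length do not contribute to the loss.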
def sequence_mask(length):
max_length = logits.size(1)
batch_size = logits.size(0)
range_tensor = torch.arange(0, max_length, dtype=seq_length.dtype, device=seq_length.device)
range_tensor = torch.stack([range_tensor]*batch_size, 0)
expanded_length = torch.stack([length]*max_length, -1)
mask = (range_tensor < expanded_length).float()
return mask
loss = nn.CrossEntropyLoss(reduction='none')(logits.permute((0, 2, 1)), ground_truth)
if seq_length is None:
loss = loss.mean()
else:
mask = sequence_mask(seq_length)
loss = (loss * mask).sum(-1) / seq_length.float()
loss = loss.mean()
return loss
"""
Training and testing for graph classification tasks in bAbI
"""
import argparse
from data_utils import get_babi_dataloaders
from ggnn_gc import GraphClsGGNN
from torch.optim import Adam
import torch
import numpy as np
def main(args):
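    # Per-task hyperparameters: GGNN hidden state size and number of edge types.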
out_feats = {18: 3}
n_etypes = {18: 2}
train_dataloader, dev_dataloader, test_dataloaders = \
get_babi_dataloaders(batch_size=args.batch_size,
train_size=args.train_num,
task_id=args.task_id,
q_type=args.question_id)
model = GraphClsGGNN(annotation_size=2,
out_feats=out_feats[args.task_id],
n_steps=5,
n_etypes=n_etypes[args.task_id],
num_cls=2)
opt = Adam(model.parameters(), lr=args.lr)
print(f'Task {args.task_id}, question_id {args.question_id}')
print(f'Training set size: {len(train_dataloader.dataset)}')
print(f'Dev set size: {len(dev_dataloader.dataset)}')
# training and dev stage
for epoch in range(args.epochs):
model.train()
for i, batch in enumerate(train_dataloader):
g, labels = batch
loss, _ = model(g, labels)
opt.zero_grad()
loss.backward()
opt.step()
if epoch % 20 == 0:
print(f'Epoch {epoch}, batch {i} loss: {loss.data}')
if epoch % 20 != 0:
continue
dev_preds = []
dev_labels = []
model.eval()
for g, labels in dev_dataloader:
with torch.no_grad():
preds = model(g)
preds = preds.data.numpy().tolist()
labels = labels.data.numpy().tolist()
dev_preds += preds
dev_labels += labels
        acc = np.equal(dev_labels, dev_preds).astype(float).tolist()
acc = sum(acc) / len(acc)
print(f"Epoch {epoch}, Dev acc {acc}")
# test stage
for i, dataloader in enumerate(test_dataloaders):
print(f'Test set {i} size: {len(dataloader.dataset)}')
test_acc_list = []
for dataloader in test_dataloaders:
test_preds = []
test_labels = []
model.eval()
for g, labels in dataloader:
with torch.no_grad():
preds = model(g)
preds = preds.data.numpy().tolist()
labels = labels.data.numpy().tolist()
test_preds += preds
test_labels += labels
        acc = np.equal(test_labels, test_preds).astype(float).tolist()
acc = sum(acc) / len(acc)
test_acc_list.append(acc)
test_acc_mean = np.mean(test_acc_list)
test_acc_std = np.std(test_acc_list)
print(f'Mean of accuracy in 10 test datasets: {test_acc_mean}, std: {test_acc_std}')
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Gated Graph Neural Networks for graph classification tasks in bAbI')
parser.add_argument('--task_id', type=int, default=18,
help='task id from 1 to 20')
parser.add_argument('--question_id', type=int, default=0,
help='question id for each task')
parser.add_argument('--train_num', type=int, default=950,
help='Number of training examples')
parser.add_argument('--batch_size', type=int, default=50,
help='batch size')
parser.add_argument('--lr', type=float, default=1e-3,
help='learning rate')
parser.add_argument('--epochs', type=int, default=200,
help='number of training epochs')
args = parser.parse_args()
main(args)
"""
Training and testing for node selection tasks in bAbI
"""
import argparse
from data_utils import get_babi_dataloaders
from ggnn_ns import NodeSelectionGGNN
from torch.optim import Adam
import torch
import numpy as np
def main(args):
out_feats = {4: 4, 15: 5, 16: 6}
n_etypes = {4: 4, 15: 2, 16: 2}
train_dataloader, dev_dataloader, test_dataloaders = \
get_babi_dataloaders(batch_size=args.batch_size,
train_size=args.train_num,
task_id=args.task_id,
q_type=args.question_id)
model = NodeSelectionGGNN(annotation_size=1,
out_feats=out_feats[args.task_id],
n_steps=5,
n_etypes=n_etypes[args.task_id])
opt = Adam(model.parameters(), lr=args.lr)
print(f'Task {args.task_id}, question_id {args.question_id}')
print(f'Training set size: {len(train_dataloader.dataset)}')
print(f'Dev set size: {len(dev_dataloader.dataset)}')
# training and dev stage
for epoch in range(args.epochs):
model.train()
for i, batch in enumerate(train_dataloader):
g, labels = batch
loss, _ = model(g, labels)
opt.zero_grad()
loss.backward()
opt.step()
print(f'Epoch {epoch}, batch {i} loss: {loss.data}')
dev_preds = []
dev_labels = []
model.eval()
for g, labels in dev_dataloader:
with torch.no_grad():
preds = model(g)
            preds = torch.stack(preds).data.numpy().tolist()
labels = labels.data.numpy().tolist()
dev_preds += preds
dev_labels += labels
        acc = np.equal(dev_labels, dev_preds).astype(float).tolist()
acc = sum(acc) / len(acc)
print(f"Epoch {epoch}, Dev acc {acc}")
# test stage
for i, dataloader in enumerate(test_dataloaders):
print(f'Test set {i} size: {len(dataloader.dataset)}')
test_acc_list = []
for dataloader in test_dataloaders:
test_preds = []
test_labels = []
model.eval()
for g, labels in dataloader:
with torch.no_grad():
preds = model(g)
            preds = torch.stack(preds).data.numpy().tolist()
labels = labels.data.numpy().tolist()
test_preds += preds
test_labels += labels
        acc = np.equal(test_labels, test_preds).astype(float).tolist()
acc = sum(acc) / len(acc)
test_acc_list.append(acc)
test_acc_mean = np.mean(test_acc_list)
test_acc_std = np.std(test_acc_list)
print(f'Mean of accuracy in 10 test datasets: {test_acc_mean}, std: {test_acc_std}')
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Gated Graph Neural Networks for node selection tasks in bAbI')
parser.add_argument('--task_id', type=int, default=16,
help='task id from 1 to 20')
parser.add_argument('--question_id', type=int, default=1,
help='question id for each task')
parser.add_argument('--train_num', type=int, default=50,
help='Number of training examples')
parser.add_argument('--batch_size', type=int, default=10,
help='batch size')
parser.add_argument('--lr', type=float, default=1e-3,
help='learning rate')
parser.add_argument('--epochs', type=int, default=100,
help='number of training epochs')
args = parser.parse_args()
main(args)
"""
Training and testing for sequence output tasks in bAbI.
Here we take task 19 'Path Finding' as an example
"""
import argparse
from data_utils import get_babi_dataloaders
from ggsnn import GGSNN
from torch.optim import Adam
import torch
import numpy as np
def main(args):
out_feats = {19: 6}
n_etypes = {19: 4}
train_dataloader, dev_dataloader, test_dataloaders = \
get_babi_dataloaders(batch_size=args.batch_size,
train_size=args.train_num,
task_id=args.task_id,
q_type=-1)
model = GGSNN(annotation_size=2,
out_feats=out_feats[args.task_id],
n_steps=5,
n_etypes=n_etypes[args.task_id],
max_seq_length=2,
num_cls=5)
opt = Adam(model.parameters(), lr=args.lr)
print(f'Task {args.task_id}')
print(f'Training set size: {len(train_dataloader.dataset)}')
print(f'Dev set size: {len(dev_dataloader.dataset)}')
# training and dev stage
for epoch in range(args.epochs):
model.train()
for i, batch in enumerate(train_dataloader):
g, ground_truths, seq_lengths = batch
loss, _ = model(g, seq_lengths, ground_truths)
opt.zero_grad()
loss.backward()
opt.step()
if epoch % 20 == 0:
print(f'Epoch {epoch}, batch {i} loss: {loss.data}')
if epoch % 20 != 0:
continue
dev_res = []
model.eval()
for g, ground_truths, seq_lengths in dev_dataloader:
with torch.no_grad():
preds = model(g, seq_lengths)
preds = preds.data.numpy().tolist()
ground_truths = ground_truths.data.numpy().tolist()
for i, p in enumerate(preds):
if p == ground_truths[i]:
dev_res.append(1.0)
else:
dev_res.append(0.0)
acc = sum(dev_res) / len(dev_res)
print(f"Epoch {epoch}, Dev acc {acc}")
# test stage
for i, dataloader in enumerate(test_dataloaders):
print(f'Test set {i} size: {len(dataloader.dataset)}')
test_acc_list = []
for dataloader in test_dataloaders:
test_res = []
model.eval()
for g, ground_truths, seq_lengths in dataloader:
with torch.no_grad():
preds = model(g, seq_lengths)
preds = preds.data.numpy().tolist()
ground_truths = ground_truths.data.numpy().tolist()
for i, p in enumerate(preds):
if p == ground_truths[i]:
test_res.append(1.0)
else:
test_res.append(0.0)
acc = sum(test_res) / len(test_res)
test_acc_list.append(acc)
test_acc_mean = np.mean(test_acc_list)
test_acc_std = np.std(test_acc_list)
print(f'Mean of accuracy in 10 test datasets: {test_acc_mean}, std: {test_acc_std}')
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Gated Graph Sequence Neural Networks for sequential output tasks in '
'bAbI')
parser.add_argument('--task_id', type=int, default=19,
help='task id from 1 to 20')
parser.add_argument('--train_num', type=int, default=250,
help='Number of training examples')
parser.add_argument('--batch_size', type=int, default=10,
help='batch size')
parser.add_argument('--lr', type=float, default=1e-3,
help='learning rate')
parser.add_argument('--epochs', type=int, default=200,
help='number of training epochs')
args = parser.parse_args()
main(args)
@@ -13,10 +13,41 @@
The folder contains a DGL implementation of Recurrent Relational Network, and its
application on sudoku solving.
## Results
## Usage
- To train the RRN for sudoku, run the following:
```
python3 train_sudoku.py --output_dir out/ --do_train --do_eval
```
Test accuracy (puzzle-level): 96.08% (paper: 96.6%)
- To use the trained model for solving sudoku, follow the example below:
```python
from sudoku_solver import solve_sudoku
q = [[9, 7, 0, 4, 0, 2, 0, 5, 3],
[0, 4, 6, 0, 9, 0, 0, 0, 0],
[0, 0, 8, 6, 0, 1, 4, 0, 7],
[0, 0, 0, 0, 0, 3, 5, 0, 0],
[7, 6, 0, 0, 0, 0, 0, 8, 2],
[0, 0, 2, 8, 0, 0, 0, 0, 0],
[6, 0, 5, 1, 0, 7, 2, 0, 0],
[0, 0, 0, 0, 6, 0, 7, 4, 0],
[4, 3, 0, 2, 0, 9, 0, 6, 1]
]
answer = solve_sudoku(q)
print(answer)
'''
[[9 7 1 4 8 2 6 5 3]
[3 4 6 7 9 5 1 2 8]
[2 5 8 6 3 1 4 9 7]
[8 1 4 9 2 3 5 7 6]
[7 6 3 5 1 4 9 8 2]
[5 9 2 8 7 6 3 1 4]
[6 8 5 1 4 7 2 3 9]
[1 2 9 3 6 8 7 4 5]
[4 3 7 2 5 9 8 6 1]]
'''
```