"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "5c53ca5ed89b30cd7b04cbf01665388e2af9f249"
Commit b2f7f0ee authored by Xiagkun Hu's avatar Xiagkun Hu Committed by Zihao Ye
Browse files

[Model] Recurrent Relational Network (RRN) on sudoku (#733)

* rrn model and sudoku

* add README

* refine the code, add doc strings

* add sudoku solver
parent e1f08644
# Recurrent Relational Network (RRN)
* Paper link: https://arxiv.org/abs/1711.08028
* Author's code repo: https://github.com/rasmusbergpalm/recurrent-relational-networks.git
## Dependencies
* PyTorch 1.0+
* DGL 0.3+
## Codes
The folder contains a DGL implementation of the Recurrent Relational Network and its
application to sudoku solving.
## Results
Run the following command:
```
python3 train_sudoku.py --output_dir out/ --do_train --do_eval
```
Test accuracy (puzzle-level): 96.08% (paper: 96.6%)
"""
Recurrent Relational Network(RRN) module
References:
- Recurrent Relational Networks
- Paper: https://arxiv.org/abs/1711.08028
- Original Code: https://github.com/rasmusbergpalm/recurrent-relational-networks
"""
import torch
from torch import nn
import dgl.function as fn
class RRNLayer(nn.Module):
    """One round of message passing for the Recurrent Relational Network.

    Every edge computes a message from its incident node states via
    ``msg_layer``; messages are dropped out, summed per destination node
    into ``ndata['m']``, and each node state is then refreshed by the
    injected ``node_update_func``.
    """

    def __init__(self, msg_layer, node_update_func, edge_drop):
        super(RRNLayer, self).__init__()
        self.msg_layer = msg_layer
        self.node_update_func = node_update_func
        self.edge_dropout = nn.Dropout(edge_drop)

    def forward(self, g):
        """Run one message-passing round on graph ``g`` (mutates ``g`` in place)."""
        g.apply_edges(self.get_msg)
        g.edata['e'] = self.edge_dropout(g.edata['e'])
        g.update_all(message_func=fn.copy_e('e', 'msg'),
                     reduce_func=fn.sum('msg', 'm'))
        g.apply_nodes(self.node_update)

    def get_msg(self, edges):
        """Build each edge message from the concatenated src/dst node states."""
        joined = torch.cat([edges.src['h'], edges.dst['h']], -1)
        return {'e': self.msg_layer(joined)}

    def node_update(self, nodes):
        """Delegate the per-node state refresh to the injected update function."""
        return self.node_update_func(nodes)
class RRN(nn.Module):
    """Recurrent Relational Network.

    Applies the same :class:`RRNLayer` for a fixed number of rounds,
    optionally collecting the node states after every round (used for
    intermediate supervision during training).
    """

    def __init__(self,
                 msg_layer,
                 node_update_func,
                 num_steps,
                 edge_drop):
        super(RRN, self).__init__()
        self.num_steps = num_steps
        self.rrn_layer = RRNLayer(msg_layer, node_update_func, edge_drop)

    def forward(self, g, get_all_outputs=True):
        """Run ``num_steps`` rounds of message passing on ``g``.

        :param g: graph with node states in ``ndata['h']`` (mutated in place)
        :param get_all_outputs: if True, return the states from every round
        :return: ``num_steps x n_nodes x h_dim`` when ``get_all_outputs``,
            otherwise the final ``n_nodes x h_dim`` states
        """
        if get_all_outputs:
            per_step = []
            for _ in range(self.num_steps):
                self.rrn_layer(g)
                per_step.append(g.ndata['h'])
            return torch.stack(per_step, 0)  # num_steps x n_nodes x h_dim
        for _ in range(self.num_steps):
            self.rrn_layer(g)
        return g.ndata['h']  # n_nodes x h_dim
"""
SudokuNN module based on RRN for solving sudoku puzzles
"""
from rrn import RRN
from torch import nn
import torch
class SudokuNN(nn.Module):
    """RRN-based sudoku solver.

    Each of the 81 cells is a node of the constraint graph. A node's input
    combines embeddings of its given digit (0 = blank), its row index and
    its column index. Messages are exchanged for ``num_steps`` rounds with
    an LSTM cell carrying each node's recurrent state, and a linear head
    predicts a digit (0-9) per node.
    """
    def __init__(self,
                 num_steps,
                 embed_size=16,
                 hidden_dim=96,
                 edge_drop=0.1):
        """
        :param num_steps: number of message-passing rounds
        :param embed_size: size of the digit/row/column embeddings
        :param hidden_dim: node-state / message-MLP hidden size
        :param edge_drop: dropout rate applied to edge messages
        """
        super(SudokuNN, self).__init__()
        self.num_steps = num_steps
        # Digits are 0-9 (0 means "blank"); rows and columns are 0-8.
        self.digit_embed = nn.Embedding(10, embed_size)
        self.row_embed = nn.Embedding(9, embed_size)
        self.col_embed = nn.Embedding(9, embed_size)
        # Projects the concatenated (digit, row, col) embeddings to hidden_dim.
        self.input_layer = nn.Sequential(
            nn.Linear(3*embed_size, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        # Recurrent node update: input is [x, aggregated message], hence 2x.
        self.lstm = nn.LSTMCell(hidden_dim*2, hidden_dim, bias=False)
        # Edge-message MLP; input is the concatenated src/dst node states.
        msg_layer = nn.Sequential(
            nn.Linear(2*hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.rrn = RRN(msg_layer, self.node_update_func, num_steps, edge_drop)
        self.output_layer = nn.Linear(hidden_dim, 10)
        self.loss_func = nn.CrossEntropyLoss()
    def forward(self, g, is_training=True):
        """Run the RRN over the (batched) sudoku graph ``g``.

        Consumes (pops) the node features 'q', 'a', 'row', 'col'.
        When training, logits from every message-passing step enter the
        loss (intermediate supervision); otherwise only the final step.

        :param g: DGL graph with ndata 'q' (puzzle), 'a' (answer),
            'row' and 'col' (cell coordinates), all long tensors
        :param is_training: also controls whether all per-step outputs
            are collected from the RRN (passed as ``get_all_outputs``)
        :return: (predicted digits, cross-entropy loss)
        """
        labels = g.ndata.pop('a')
        input_digits = self.digit_embed(g.ndata.pop('q'))
        rows = self.row_embed(g.ndata.pop('row'))
        cols = self.col_embed(g.ndata.pop('col'))
        x = self.input_layer(torch.cat([input_digits, rows, cols], -1))
        # 'x' is kept as the constant per-node input; 'h' is the evolving
        # node state; 'rnn_h'/'rnn_c' hold the LSTM recurrent state.
        g.ndata['x'] = x
        g.ndata['h'] = x
        g.ndata['rnn_h'] = torch.zeros_like(x, dtype=torch.float)
        g.ndata['rnn_c'] = torch.zeros_like(x, dtype=torch.float)
        outputs = self.rrn(g, is_training)
        logits = self.output_layer(outputs)
        preds = torch.argmax(logits, -1)
        if is_training:
            # Supervise every step: replicate labels num_steps times and
            # flatten so CrossEntropyLoss sees [N, 10] vs [N].
            labels = torch.stack([labels]*self.num_steps, 0)
            logits = logits.view([-1, 10])
            labels = labels.view([-1])
        loss = self.loss_func(logits, labels)
        return preds, loss
    def node_update_func(self, nodes):
        """LSTM node update: feed [input, aggregated message] into the cell
        and expose the new hidden state as the node state 'h'."""
        x, h, m, c = nodes.data['x'], nodes.data['rnn_h'], nodes.data['m'], nodes.data['rnn_c']
        new_h, new_c = self.lstm(torch.cat([x, m], -1), (h, c))
        return {'h': new_h, 'rnn_c': new_c, 'rnn_h': new_h}
import csv
import os
import urllib.request
import zipfile
import numpy as np
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.dataset import Dataset
import torch
import dgl
from copy import copy
def _basic_sudoku_graph():
    """Build the 81-node sudoku constraint graph.

    Cells are numbered 0-80 row-major. Each cell receives a directed edge
    from every peer in its row, its column and its 3x3 box (box peers that
    also share a row or column appear twice, matching the reference
    implementation). Edge-insertion order is preserved exactly.
    """
    boxes = [[0, 1, 2, 9, 10, 11, 18, 19, 20],
             [3, 4, 5, 12, 13, 14, 21, 22, 23],
             [6, 7, 8, 15, 16, 17, 24, 25, 26],
             [27, 28, 29, 36, 37, 38, 45, 46, 47],
             [30, 31, 32, 39, 40, 41, 48, 49, 50],
             [33, 34, 35, 42, 43, 44, 51, 52, 53],
             [54, 55, 56, 63, 64, 65, 72, 73, 74],
             [57, 58, 59, 66, 67, 68, 75, 76, 77],
             [60, 61, 62, 69, 70, 71, 78, 79, 80]]
    g = dgl.DGLGraph()
    g.add_nodes(81)

    src_nodes, dst_nodes = [], []

    def connect(src, dst):
        # Record a directed edge src -> dst (added in bulk below).
        src_nodes.append(src)
        dst_nodes.append(dst)

    for cell in range(81):
        r, c = divmod(cell, 9)
        # Peers in the same row and same column, interleaved to keep the
        # original edge ordering (and thus identical edge IDs).
        row_peer = r * 9
        col_peer = c
        for _ in range(9):
            if row_peer != cell:
                connect(row_peer, cell)
            if col_peer != cell:
                connect(col_peer, cell)
            row_peer += 1
            col_peer += 9
        # Peers in the same 3x3 box.
        for peer in boxes[(r // 3) * 3 + c // 3]:
            if peer != cell:
                connect(peer, cell)

    g.add_edges(src_nodes, dst_nodes)
    return g
class ListDataset(Dataset):
    """Zip-style dataset over several equal-length sequences.

    ``dataset[i]`` returns a tuple holding the i-th item of every sequence.
    """

    def __init__(self, *lists_of_data):
        # All sequences must share one length (vacuously true when empty).
        assert len({len(seq) for seq in lists_of_data}) <= 1
        self.lists_of_data = lists_of_data

    def __getitem__(self, index):
        return tuple(seq[index] for seq in self.lists_of_data)

    def __len__(self):
        return len(self.lists_of_data[0])
def _get_sudoku_dataset(segment='train'):
    """Load one segment of the sudoku-hard dataset, downloading it if needed.

    :param segment: one of 'train', 'valid', 'test'
    :return: list of (question, answer) pairs, each an 81-long list of ints
        (0 marks a blank cell in the question)
    """
    assert segment in ['train', 'valid', 'test']
    url = "https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/sudoku-hard.zip"
    zip_fname = "/tmp/sudoku-hard.zip"
    dest_dir = '/tmp/sudoku-hard/'

    # Fetch and unpack once; later calls reuse the extracted directory.
    if not os.path.exists(dest_dir):
        print("Downloading data...")
        urllib.request.urlretrieve(url, zip_fname)
        with zipfile.ZipFile(zip_fname) as f:
            f.extractall('/tmp/')

    fname = segment + '.csv'
    print("Reading %s..." % fname)

    def digits(field):
        # '004..9' -> [0, 0, 4, ..., 9]
        return [int(ch) for ch in field]

    with open(dest_dir + fname) as f:
        return [(digits(q), digits(a)) for q, a in csv.reader(f, delimiter=',')]
def sudoku_dataloader(batch_size, segment='train'):
    """
    Get a DataLoader instance for dataset of sudoku. Every iteration of the dataloader returns
    a DGLGraph instance, the ndata of the graph contains:
    'q': question, e.g. the sudoku puzzle to be solved, the position is to be filled with number from 1-9
    if the value in the position is 0
    'a': answer, the ground truth of the sudoku puzzle
    'row': row index for each position in the grid
    'col': column index for each position in the grid
    :param batch_size: Batch size for the dataloader
    :param segment: The segment of the datasets, must in ['train', 'valid', 'test']
    :return: A pytorch DataLoader instance
    """
    data = _get_sudoku_dataset(segment)
    q, a = zip(*data)
    dataset = ListDataset(q, a)
    # Shuffle only the training split; keep valid/test in file order.
    if segment == 'train':
        data_sampler = RandomSampler(dataset)
    else:
        data_sampler = SequentialSampler(dataset)
    # One template constraint graph shared by all samples; each collated
    # sample gets a copy with its own node features attached.
    basic_graph = _basic_sudoku_graph()
    sudoku_indices = np.arange(0, 81)
    rows = sudoku_indices // 9   # row index (0-8) of each of the 81 cells
    cols = sudoku_indices % 9    # column index (0-8) of each cell
    def collate_fn(batch):
        # Turn a list of (question, answer) pairs into one batched DGLGraph.
        graph_list = []
        for q, a in batch:
            q = torch.tensor(q, dtype=torch.long)
            a = torch.tensor(a, dtype=torch.long)
            # NOTE(review): shallow-copies the shared template graph; this
            # relies on DGLGraph copy semantics where assigning ndata on the
            # copy does not mutate the template — verify for the installed
            # DGL version.
            graph = copy(basic_graph)
            graph.ndata['q'] = q # q means question
            graph.ndata['a'] = a # a means answer
            graph.ndata['row'] = torch.tensor(rows, dtype=torch.long)
            graph.ndata['col'] = torch.tensor(cols, dtype=torch.long)
            graph_list.append(graph)
        batch_graph = dgl.batch(graph_list)
        return batch_graph
    dataloader = DataLoader(dataset, batch_size, sampler=data_sampler, collate_fn=collate_fn)
    return dataloader
import os
import urllib.request
import torch
import numpy as np
from sudoku_data import _basic_sudoku_graph
def solve_sudoku(puzzle):
    """
    Solve sudoku puzzle using RRN.

    Downloads a pretrained checkpoint into ./ckpt on first use, builds the
    81-node constraint graph and runs one inference pass on CPU.

    :param puzzle: an array-like data with shape [9, 9], blank positions are filled with 0
    :return: a [9, 9] shaped numpy array
    """
    # np.long was removed in NumPy 1.24; np.int64 is the equivalent dtype.
    puzzle = np.array(puzzle, dtype=np.int64).reshape([-1])

    model_path = 'ckpt'
    if not os.path.exists(model_path):
        os.mkdir(model_path)
    model_filename = os.path.join(model_path, 'rrn-sudoku.pkl')
    if not os.path.exists(model_filename):
        print('Downloading model...')
        url = 'https://s3.us-east-2.amazonaws.com/dgl.ai/models/rrn-sudoku.pkl'
        urllib.request.urlretrieve(url, model_filename)
    # NOTE(security): torch.load unpickles the whole model object, which can
    # execute arbitrary code — only load checkpoints from a trusted source
    # (here: the DGL model zoo).
    model = torch.load(model_filename, map_location='cpu')

    g = _basic_sudoku_graph()
    sudoku_indices = np.arange(0, 81)
    rows = sudoku_indices // 9
    cols = sudoku_indices % 9
    g.ndata['row'] = torch.tensor(rows, dtype=torch.long)
    g.ndata['col'] = torch.tensor(cols, dtype=torch.long)
    g.ndata['q'] = torch.tensor(puzzle, dtype=torch.long)
    # 'a' (the answer) is required by the model's forward pass; feed the
    # puzzle itself as a dummy — the loss it produces is discarded below.
    g.ndata['a'] = torch.tensor(puzzle, dtype=torch.long)

    pred, _ = model(g, False)
    pred = pred.cpu().data.numpy().reshape([9, 9])
    return pred
if __name__ == '__main__':
    # A sample hard puzzle; zeros mark the blanks to be filled in.
    sample_puzzle = [
        [9, 7, 0, 4, 0, 2, 0, 5, 3],
        [0, 4, 6, 0, 9, 0, 0, 0, 0],
        [0, 0, 8, 6, 0, 1, 4, 0, 7],
        [0, 0, 0, 0, 0, 3, 5, 0, 0],
        [7, 6, 0, 0, 0, 0, 0, 8, 2],
        [0, 0, 2, 8, 0, 0, 0, 0, 0],
        [6, 0, 5, 1, 0, 7, 2, 0, 0],
        [0, 0, 0, 0, 6, 0, 7, 4, 0],
        [4, 3, 0, 2, 0, 9, 0, 6, 1],
    ]
    print(solve_sudoku(sample_puzzle))
from sudoku_data import sudoku_dataloader
import argparse
from sudoku import SudokuNN
import torch
from torch.optim import Adam
import os
import numpy as np
def _to_device(g, device):
    """Move the graph's 'q'/'a'/'row'/'col' node features onto *device* in place."""
    for key in ('q', 'a', 'row', 'col'):
        g.ndata[key] = g.ndata[key].to(device)


def _evaluate(model, dataloader, device):
    """Run *model* over *dataloader* without gradients.

    Accuracy is puzzle-level: a prediction counts as correct only when all
    81 cells match the answer.

    :return: (mean loss, puzzle-level accuracy)
    """
    model.eval()
    losses = []
    solved = []
    for g in dataloader:
        _to_device(g, device)
        target = g.ndata['a'].view([-1, 81])
        with torch.no_grad():
            preds, loss = model(g, is_training=False)
        preds = preds.view([-1, 81])
        for i in range(preds.size(0)):
            solved.append(int(torch.equal(preds[i, :], target[i, :])))
        losses.append(loss.cpu().detach().data)
    return np.mean(losses), sum(solved) / len(solved)


def main(args):
    """Train and/or evaluate the RRN sudoku solver per command-line args.

    Training saves the best-on-dev and the final model into args.output_dir;
    evaluation loads the best checkpoint and reports test loss/accuracy.
    """
    if args.gpu < 0 or not torch.cuda.is_available():
        device = torch.device('cpu')
    else:
        device = torch.device('cuda', args.gpu)

    if args.do_train:
        if not os.path.exists(args.output_dir):
            os.mkdir(args.output_dir)
        model = SudokuNN(num_steps=args.steps, edge_drop=args.edge_drop).to(device)
        train_dataloader = sudoku_dataloader(args.batch_size, segment='train')
        dev_dataloader = sudoku_dataloader(args.batch_size, segment='valid')
        opt = Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

        best_dev_acc = 0.0
        for epoch in range(args.epochs):
            model.train()
            for i, g in enumerate(train_dataloader):
                _to_device(g, device)
                _, loss = model(g)
                opt.zero_grad()
                loss.backward()
                opt.step()
                if i % 100 == 0:
                    print(f"Epoch {epoch}, batch {i}, loss {loss.cpu().data}")

            # dev
            print("\n=========Dev step========")
            dev_loss, dev_acc = _evaluate(model, dev_dataloader, device)
            print(f"Dev loss {dev_loss}, accuracy {dev_acc}")
            if dev_acc >= best_dev_acc:
                # Keep the checkpoint with the best puzzle-level dev accuracy.
                torch.save(model, os.path.join(args.output_dir, 'model_best.bin'))
                best_dev_acc = dev_acc
            print(f"Best dev accuracy {best_dev_acc}\n")
        torch.save(model, os.path.join(args.output_dir, 'model_final.bin'))

    if args.do_eval:
        model_path = os.path.join(args.output_dir, 'model_best.bin')
        if not os.path.exists(model_path):
            raise FileNotFoundError("Saved model not Found!")
        model = torch.load(model_path).to(device)
        test_dataloader = sudoku_dataloader(args.batch_size, segment='test')
        print("\n=========Test step========")
        test_loss, test_acc = _evaluate(model, test_dataloader, device)
        print(f"Test loss {test_loss}, accuracy {test_acc}\n")
if __name__ == "__main__":
    # Command-line interface; option order mirrors the original so that
    # --help output is unchanged.
    parser = argparse.ArgumentParser(description='Recurrent Relational Network on sudoku task.')
    parser.add_argument("--output_dir", type=str, default=None, required=True, help="The directory to save model")
    parser.add_argument("--do_train", default=False, action="store_true", help="Train the model")
    parser.add_argument("--do_eval", default=False, action="store_true", help="Evaluate the model on test data")
    parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=64, help="Batch size")
    parser.add_argument("--edge_drop", type=float, default=0.4, help="Dropout rate at edges.")
    parser.add_argument("--steps", type=int, default=32, help="Number of message passing steps.")
    parser.add_argument("--gpu", type=int, default=-1, help="gpu")
    parser.add_argument("--lr", type=float, default=2e-4, help="Learning rate")
    parser.add_argument("--weight_decay", type=float, default=1e-4, help="weight decay (L2 penalty)")
    main(parser.parse_args())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment