Unverified Commit 912da18c authored by Xiangkun Hu's avatar Xiangkun Hu Committed by GitHub
Browse files

[Example] Update results in Sudoku example with RRN (#1998)

* PPIDataset

* Revert "PPIDataset"

This reverts commit 264bd0c960cfa698a7bb946dad132bf52c2d0c8a.

* remove duplicate edges

* Update README.md

* change model save and load

* Update README.md

* Update README.md
parent d1cf5c38
...@@ -17,12 +17,24 @@ application on sudoku solving. ...@@ -17,12 +17,24 @@ application on sudoku solving.
- To train the RRN for sudoku, run the following - To train the RRN for sudoku, run the following
``` ```
python3 train_sudoku.py --output_dir out/ --do_train --do_eval python3 train_sudoku.py --output_dir out/ --do_train
``` ```
Test accuracy (puzzle-level): 96.08% (paper: 96.6%)
- Test with specified aggregation steps:
```
python3 train_sudoku.py --output_dir out/ --do_eval --steps 64
```
Test accuracy (puzzle-level):
| | 32 steps | 64 steps |
| ----- | :------: | :------: |
| Paper | 94.1 | 96.6 |
| DGL | 95.3 | 98.9 |
- To use the trained model for solving sudoku, follow the example bellow: - To use the trained model for solving sudoku, follow the example bellow:
```python ```python
from sudoku_solver import solve_sudoku from sudoku_solver import solve_sudoku
......
...@@ -4,6 +4,7 @@ import urllib.request ...@@ -4,6 +4,7 @@ import urllib.request
import torch import torch
import numpy as np import numpy as np
from sudoku_data import _basic_sudoku_graph from sudoku_data import _basic_sudoku_graph
from sudoku import SudokuNN
def solve_sudoku(puzzle): def solve_sudoku(puzzle):
...@@ -23,7 +24,9 @@ def solve_sudoku(puzzle): ...@@ -23,7 +24,9 @@ def solve_sudoku(puzzle):
url = 'https://data.dgl.ai/models/rrn-sudoku.pkl' url = 'https://data.dgl.ai/models/rrn-sudoku.pkl'
urllib.request.urlretrieve(url, model_filename) urllib.request.urlretrieve(url, model_filename)
model = torch.load(model_filename, map_location='cpu') model = SudokuNN(num_steps=64, edge_drop=0.)
model.load_state_dict(torch.load(model_filename, map_location='cpu'))
model.eval()
g = _basic_sudoku_graph() g = _basic_sudoku_graph()
sudoku_indices = np.arange(0, 81) sudoku_indices = np.arange(0, 81)
......
...@@ -13,11 +13,12 @@ def main(args): ...@@ -13,11 +13,12 @@ def main(args):
else: else:
device = torch.device('cuda', args.gpu) device = torch.device('cuda', args.gpu)
model = SudokuNN(num_steps=args.steps, edge_drop=args.edge_drop)
if args.do_train: if args.do_train:
if not os.path.exists(args.output_dir): if not os.path.exists(args.output_dir):
os.mkdir(args.output_dir) os.mkdir(args.output_dir)
model.to(device)
model = SudokuNN(num_steps=args.steps, edge_drop=args.edge_drop).to(device)
train_dataloader = sudoku_dataloader(args.batch_size, segment='train') train_dataloader = sudoku_dataloader(args.batch_size, segment='train')
dev_dataloader = sudoku_dataloader(args.batch_size, segment='valid') dev_dataloader = sudoku_dataloader(args.batch_size, segment='valid')
...@@ -57,18 +58,20 @@ def main(args): ...@@ -57,18 +58,20 @@ def main(args):
dev_acc = sum(dev_res) / len(dev_res) dev_acc = sum(dev_res) / len(dev_res)
print(f"Dev loss {np.mean(dev_loss)}, accuracy {dev_acc}") print(f"Dev loss {np.mean(dev_loss)}, accuracy {dev_acc}")
if dev_acc >= best_dev_acc: if dev_acc >= best_dev_acc:
torch.save(model, os.path.join(args.output_dir, 'model_best.bin')) torch.save(model.state_dict(), os.path.join(args.output_dir, 'model_best.bin'))
best_dev_acc = dev_acc best_dev_acc = dev_acc
print(f"Best dev accuracy {best_dev_acc}\n") print(f"Best dev accuracy {best_dev_acc}\n")
torch.save(model, os.path.join(args.output_dir, 'model_final.bin')) torch.save(model.state_dict(), os.path.join(args.output_dir, 'model_final.bin'))
if args.do_eval: if args.do_eval:
model_path = os.path.join(args.output_dir, 'model_best.bin') model_path = os.path.join(args.output_dir, 'model_best.bin')
if not os.path.exists(model_path): if not os.path.exists(model_path):
raise FileNotFoundError("Saved model not Found!") raise FileNotFoundError("Saved model not Found!")
model = torch.load(model_path).to(device) model.load_state_dict(torch.load(model_path))
model.to(device)
test_dataloader = sudoku_dataloader(args.batch_size, segment='test') test_dataloader = sudoku_dataloader(args.batch_size, segment='test')
print("\n=========Test step========") print("\n=========Test step========")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment