Unverified Commit 3fa8d755 authored by Xiangkun Hu's avatar Xiangkun Hu Committed by GitHub
Browse files

[Example] Remove duplicated edges in RRN Sudoku example (#1946)

* PPIDataset

* Revert "PPIDataset"

This reverts commit 264bd0c960cfa698a7bb946dad132bf52c2d0c8a.

* remove duplicate edges

* Update README.md
parent f9bde91f
# Recurrent Relational Network (RRN) # Recurrent Relational Network (RRN)
* Paper link: https://arxiv.org/abs/1711.08028 * Paper link: https://arxiv.org/abs/1711.08028
* Author's code repo: https://github.com/rasmusbergpalm/recurrent-relational-networks.git * Author's code repo: https://github.com/rasmusbergpalm/recurrent-relational-networks
## Dependencies ## Dependencies
* PyTorch 1.0+ * PyTorch 1.0+
* DGL 0.3+ * DGL 0.5+
## Codes ## Codes
......
...@@ -20,25 +20,24 @@ def _basic_sudoku_graph(): ...@@ -20,25 +20,24 @@ def _basic_sudoku_graph():
[54, 55, 56, 63, 64, 65, 72, 73, 74], [54, 55, 56, 63, 64, 65, 72, 73, 74],
[57, 58, 59, 66, 67, 68, 75, 76, 77], [57, 58, 59, 66, 67, 68, 75, 76, 77],
[60, 61, 62, 69, 70, 71, 78, 79, 80]] [60, 61, 62, 69, 70, 71, 78, 79, 80]]
g = dgl.DGLGraph() edges = set()
g.add_nodes(81)
for i in range(81): for i in range(81):
row, col = i // 9, i % 9 row, col = i // 9, i % 9
# same row and col # same row and col
row_src = row * 9 row_src = row * 9
col_src = col col_src = col
for _ in range(9): for _ in range(9):
if row_src != i: edges.add((row_src, i))
g.add_edges(row_src, i) edges.add((col_src, i))
if col_src != i:
g.add_edges(col_src, i)
row_src += 1 row_src += 1
col_src += 9 col_src += 9
# same grid # same grid
grid_row, grid_col = row // 3, col // 3 grid_row, grid_col = row // 3, col // 3
for n in grids[grid_row*3 + grid_col]: for n in grids[grid_row*3 + grid_col]:
if n != i: if n != i:
g.add_edges(n, i) edges.add((n, i))
edges = list(edges)
g = dgl.graph(edges)
return g return g
...@@ -83,6 +82,7 @@ def _get_sudoku_dataset(segment='train'): ...@@ -83,6 +82,7 @@ def _get_sudoku_dataset(segment='train'):
return encoded return encoded
data = encode(data) data = encode(data)
print(f'Number of puzzles in {segment} set : {len(data)}')
return data return data
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment