Unverified Commit 3fa8d755 authored by Xiangkun Hu's avatar Xiangkun Hu Committed by GitHub
Browse files

[Example] Remove duplicated edges in RRN Sudoku example (#1946)

* PPIDataset

* Revert "PPIDataset"

This reverts commit 264bd0c960cfa698a7bb946dad132bf52c2d0c8a.

* remove duplicate edges

* Update README.md
parent f9bde91f
# Recurrent Relational Network (RRN)
* Paper link: https://arxiv.org/abs/1711.08028
* Author's code repo: https://github.com/rasmusbergpalm/recurrent-relational-networks.git
* Author's code repo: https://github.com/rasmusbergpalm/recurrent-relational-networks
## Dependencies
* PyTorch 1.0+
* DGL 0.3+
* DGL 0.5+
## Codes
......
......@@ -20,25 +20,24 @@ def _basic_sudoku_graph():
[54, 55, 56, 63, 64, 65, 72, 73, 74],
[57, 58, 59, 66, 67, 68, 75, 76, 77],
[60, 61, 62, 69, 70, 71, 78, 79, 80]]
g = dgl.DGLGraph()
g.add_nodes(81)
edges = set()
for i in range(81):
row, col = i // 9, i % 9
# same row and col
row_src = row * 9
col_src = col
for _ in range(9):
if row_src != i:
g.add_edges(row_src, i)
if col_src != i:
g.add_edges(col_src, i)
edges.add((row_src, i))
edges.add((col_src, i))
row_src += 1
col_src += 9
# same grid
grid_row, grid_col = row // 3, col // 3
for n in grids[grid_row*3 + grid_col]:
if n != i:
g.add_edges(n, i)
edges.add((n, i))
edges = list(edges)
g = dgl.graph(edges)
return g
......@@ -83,6 +82,7 @@ def _get_sudoku_dataset(segment='train'):
return encoded
data = encode(data)
print(f'Number of puzzles in {segment} set : {len(data)}')
return data
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment