"""Sudoku dataset utilities: constraint-graph construction and data loading."""
import csv
import os
import urllib.request
import zipfile
from copy import copy

import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.dataset import Dataset

import dgl


def _basic_sudoku_graph():
    """Build the 81-node constraint graph of a 9x9 sudoku board.

    Nodes are cells 0-80 in row-major order.  A directed edge (u, v) is
    added whenever cell u constrains cell v, i.e. u lies in the same row,
    column or 3x3 sub-grid as v.  The row/column scan also produces the
    self-loop (v, v), which is kept; the sub-grid scan explicitly skips it.

    :return: A ``dgl.DGLGraph`` over the 81 cells with constraint edges.
    """
    # Node ids of the nine 3x3 sub-grids, listed row-major.
    grids = [
        [0, 1, 2, 9, 10, 11, 18, 19, 20],
        [3, 4, 5, 12, 13, 14, 21, 22, 23],
        [6, 7, 8, 15, 16, 17, 24, 25, 26],
        [27, 28, 29, 36, 37, 38, 45, 46, 47],
        [30, 31, 32, 39, 40, 41, 48, 49, 50],
        [33, 34, 35, 42, 43, 44, 51, 52, 53],
        [54, 55, 56, 63, 64, 65, 72, 73, 74],
        [57, 58, 59, 66, 67, 68, 75, 76, 77],
        [60, 61, 62, 69, 70, 71, 78, 79, 80],
    ]
    edge_set = set()
    for cell in range(81):
        r, c = divmod(cell, 9)
        # Cells sharing the row or the column (this includes `cell` itself,
        # so a self-loop is added here).
        for k in range(9):
            edge_set.add((r * 9 + k, cell))
            edge_set.add((k * 9 + c, cell))
        # All other cells inside the same 3x3 sub-grid.
        for neighbor in grids[(r // 3) * 3 + c // 3]:
            if neighbor != cell:
                edge_set.add((neighbor, cell))
    return dgl.graph(list(edge_set))


class ListDataset(Dataset):
    """A map-style dataset over several parallel sequences.

    ``dataset[i]`` returns a tuple holding the i-th item of every sequence
    passed to the constructor; all sequences must have the same length.
    """

    def __init__(self, *lists_of_data):
        # Every sequence must match the length of the first one.
        assert all(len(lists_of_data[0]) == len(d) for d in lists_of_data)
        self.lists_of_data = lists_of_data

    def __getitem__(self, index):
        # One item from each stored sequence, in constructor order.
        return tuple(seq[index] for seq in self.lists_of_data)

    def __len__(self):
        # All sequences share one length; report the first one's.
        return len(self.lists_of_data[0])


def _get_sudoku_dataset(segment="train"):
    """Download (once) and parse one split of the sudoku-hard dataset.

    The archive is fetched to /tmp and extracted on first use; later calls
    reuse the extracted files.

    :param segment: Which split to load: 'train', 'valid' or 'test'.
    :return: A list of (question, answer) pairs; each element is a list of
        81 ints, with 0 marking an empty cell in the question.
    :raises ValueError: If ``segment`` is not a known split name.
    """
    # Raise instead of assert so the validation survives ``python -O``.
    if segment not in ("train", "valid", "test"):
        raise ValueError(
            "segment must be one of 'train', 'valid' or 'test', got %r"
            % (segment,)
        )

    url = "https://data.dgl.ai/dataset/sudoku-hard.zip"
    zip_fname = "/tmp/sudoku-hard.zip"
    dest_dir = "/tmp/sudoku-hard/"

    # Download and extract only when the extracted directory is absent.
    if not os.path.exists(dest_dir):
        print("Downloading data...")
        urllib.request.urlretrieve(url, zip_fname)
        with zipfile.ZipFile(zip_fname) as f:
            f.extractall("/tmp/")

    def read_csv(fname):
        # Each CSV row is "<question>,<answer>" with 81-character strings.
        print("Reading %s..." % fname)
        with open(os.path.join(dest_dir, fname)) as f:
            reader = csv.reader(f, delimiter=",")
            return [(q, a) for q, a in reader]

    data = read_csv(segment + ".csv")

    def encode(samples):
        # "0030206..." -> [0, 0, 3, 0, 2, 0, 6, ...]
        def parse(x):
            return list(map(int, list(x)))

        return [(parse(q), parse(a)) for q, a in samples]

    data = encode(data)
    print(f"Number of puzzles in {segment} set : {len(data)}")

    return data


def sudoku_dataloader(batch_size, segment="train"):
    """
    Get a DataLoader instance for dataset of sudoku. Every iteration of the dataloader returns
    a DGLGraph instance, the ndata of the graph contains:
    'q': question, e.g. the sudoku puzzle to be solved, the position is to be filled with number from 1-9
         if the value in the position is 0
    'a': answer, the ground truth of the sudoku puzzle
    'row': row index for each position in the grid
    'col': column index for each position in the grid
    :param batch_size: Batch size for the dataloader
    :param segment: The segment of the datasets, must in ['train', 'valid', 'test']
    :return: A pytorch DataLoader instance
    """
    data = _get_sudoku_dataset(segment)
    q, a = zip(*data)

    dataset = ListDataset(q, a)
    # Shuffle only the training split; evaluation splits keep file order.
    if segment == "train":
        data_sampler = RandomSampler(dataset)
    else:
        data_sampler = SequentialSampler(dataset)

    basic_graph = _basic_sudoku_graph()
    sudoku_indices = np.arange(0, 81)
    # Row/col indices are identical for every puzzle; build the tensors once
    # here instead of once per sample inside the collate function.
    row_tensor = torch.tensor(sudoku_indices // 9, dtype=torch.long)
    col_tensor = torch.tensor(sudoku_indices % 9, dtype=torch.long)

    def collate_fn(batch):
        graph_list = []
        for q, a in batch:
            q = torch.tensor(q, dtype=torch.long)
            a = torch.tensor(a, dtype=torch.long)
            # NOTE(review): assumes copy() gives each sample an independent
            # ndata frame over the shared structure — original behavior kept.
            graph = copy(basic_graph)
            graph.ndata["q"] = q  # q means question
            graph.ndata["a"] = a  # a means answer
            graph.ndata["row"] = row_tensor
            graph.ndata["col"] = col_tensor
            graph_list.append(graph)
        # Merge the per-puzzle graphs into one batched DGLGraph.
        return dgl.batch(graph_list)

    dataloader = DataLoader(
        dataset, batch_size, sampler=data_sampler, collate_fn=collate_fn
    )
    return dataloader