# sudoku_data.py
import csv
import os
import urllib.request
import zipfile
import numpy as np
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.dataset import Dataset
import torch
import dgl
from copy import copy


def _basic_sudoku_graph():
    """Build the 81-node constraint graph shared by every sudoku puzzle.

    Cells are numbered row-major, so cell ``i`` sits at row ``i // 9`` and
    column ``i % 9``. For every cell, an edge is inserted from each other cell
    of its row, of its column and of its 3x3 box, pointing at the cell.

    All edges are collected first and handed to ``add_edges`` in a single
    call, in the same order the per-edge insertion used to produce them.

    :return: A ``dgl.DGLGraph`` with 81 nodes and the sudoku constraint edges
    """
    sources, targets = [], []
    for cell in range(81):
        row, col = divmod(cell, 9)

        # Row and column peers, interleaved: k-th cell of the row first,
        # then k-th cell of the column.
        for k in range(9):
            row_peer = row * 9 + k
            col_peer = k * 9 + col
            if row_peer != cell:
                sources.append(row_peer)
                targets.append(cell)
            if col_peer != cell:
                sources.append(col_peer)
                targets.append(cell)

        # Peers in the same 3x3 box, visited row by row from its top-left cell.
        # NOTE(review): the four box peers that also share the cell's row or
        # column were already added above, so they are inserted a second time
        # here. Kept as-is to preserve the original edge list.
        top, left = row - row % 3, col - col % 3
        for r in range(top, top + 3):
            for c in range(left, left + 3):
                box_peer = r * 9 + c
                if box_peer != cell:
                    sources.append(box_peer)
                    targets.append(cell)

    g = dgl.DGLGraph()
    g.add_nodes(81)
    g.add_edges(sources, targets)
    return g


class ListDataset(Dataset):
    """Dataset that walks several equally long sequences in lockstep.

    Item ``i`` is the tuple formed by the ``i``-th element of every sequence
    given to the constructor, in the order they were passed.
    """

    def __init__(self, *lists_of_data):
        # Every sequence is measured against the first one.
        assert all(len(seq) == len(lists_of_data[0]) for seq in lists_of_data)
        self.lists_of_data = lists_of_data

    def __getitem__(self, index):
        """Return the ``index``-th element of each sequence as a tuple."""
        return tuple([seq[index] for seq in self.lists_of_data])

    def __len__(self):
        """Return the length of the first sequence."""
        first_seq = self.lists_of_data[0]
        return len(first_seq)


def _get_sudoku_dataset(segment='train', data_dir='/tmp/'):
    """Load one split of the "sudoku-hard" dataset, downloading it if needed.

    The archive is fetched and unpacked into ``data_dir`` whenever the CSV
    file of the requested split is not found under ``data_dir/sudoku-hard``.
    Each non-empty CSV row must hold exactly two fields, the puzzle and its
    solution, both written as strings of digits.

    :param segment: The split to load, must be in ['train', 'valid', 'test']
    :param data_dir: Directory that receives the downloaded zip and the
        extracted ``sudoku-hard`` folder; the default keeps the original
        '/tmp/' location
    :return: A list of ``(question, answer)`` pairs, each element being a list
        with one int per character of the corresponding CSV field
    """
    assert segment in ['train', 'valid', 'test']
    url = "https://data.dgl.ai/dataset/sudoku-hard.zip"
    zip_fname = os.path.join(data_dir, "sudoku-hard.zip")
    dest_dir = os.path.join(data_dir, "sudoku-hard")
    fname = segment + '.csv'
    csv_path = os.path.join(dest_dir, fname)

    # Check for the file that is actually needed rather than the directory,
    # so an interrupted or partial extraction is repaired on the next call.
    if not os.path.exists(csv_path):
        print("Downloading data...")
        os.makedirs(data_dir, exist_ok=True)
        urllib.request.urlretrieve(url, zip_fname)
        # The archive is expected to contain a top-level "sudoku-hard" folder.
        with zipfile.ZipFile(zip_fname) as f:
            f.extractall(data_dir)

    print("Reading %s..." % fname)
    # newline='' lets the csv module do its own line-ending handling.
    with open(csv_path, newline='') as f:
        reader = csv.reader(f, delimiter=',')
        # csv.reader yields [] for blank lines; skip them instead of failing
        # on the two-field unpacking.
        return [
            ([int(ch) for ch in q], [int(ch) for ch in a])
            for q, a in (row for row in reader if row)
        ]


def sudoku_dataloader(batch_size, segment='train'):
    """
    Build a DataLoader over one split of the sudoku dataset. Each batch it yields is the
    result of ``dgl.batch`` over one graph per puzzle, and every graph carries this ndata:
    'q': question, i.e. the puzzle to solve; a 0 marks a position that has to be filled
         with a number from 1-9
    'a': answer, the ground truth of the puzzle
    'row': row index of each position in the grid
    'col': column index of each position in the grid
    The 'train' segment is drawn with a RandomSampler, any other segment with a
    SequentialSampler.
    :param batch_size: Batch size for the dataloader
    :param segment: The segment of the dataset, must be in ['train', 'valid', 'test']
    :return: A pytorch DataLoader instance
    """
    questions, answers = zip(*_get_sudoku_dataset(segment))
    dataset = ListDataset(questions, answers)

    sampler_cls = RandomSampler if segment == 'train' else SequentialSampler
    data_sampler = sampler_cls(dataset)

    basic_graph = _basic_sudoku_graph()
    # Row-major numbering: cell i is at row i // 9, column i % 9.
    rows, cols = np.divmod(np.arange(81), 9)

    def collate_fn(batch):
        def to_graph(question, answer):
            # NOTE(review): copy() is shallow; confirm that the installed DGL
            # version gives each copy its own feature storage, so that graphs
            # of one batch do not overwrite each other's ndata.
            graph = copy(basic_graph)
            graph.ndata['q'] = torch.tensor(question, dtype=torch.long)
            graph.ndata['a'] = torch.tensor(answer, dtype=torch.long)
            graph.ndata['row'] = torch.tensor(rows, dtype=torch.long)
            graph.ndata['col'] = torch.tensor(cols, dtype=torch.long)
            return graph

        return dgl.batch([to_graph(question, answer) for question, answer in batch])

    return DataLoader(dataset, batch_size=batch_size, sampler=data_sampler,
                      collate_fn=collate_fn)