# sudoku.py
"""
SudokuNN module based on RRN for solving sudoku puzzles
"""

import torch
from torch import nn

from rrn import RRN


class SudokuNN(nn.Module):
    """Sudoku solver network built around an ``RRN`` message-passing core.

    Each graph node is described by three embedded integer features (read
    from the node data keys ``"q"``, ``"row"`` and ``"col"``), which are
    concatenated and passed through an MLP to form the node input ``x``.
    The ``RRN`` module then runs on the graph, using an MLP as its message
    function and :meth:`node_update_func` (an LSTM cell) as its node update.
    A final linear layer maps the RRN output to 10-way logits.

    Args:
        num_steps: Number of steps, forwarded to ``RRN`` and used to tile the
            labels during training.
        embed_size: Size of each of the three feature embeddings.
        hidden_dim: Width of the MLPs and of the LSTM hidden state.
        edge_drop: Forwarded to ``RRN`` (presumably an edge dropout rate --
            confirm against the ``rrn`` module).
    """

    def __init__(self, num_steps, embed_size=16, hidden_dim=96, edge_drop=0.1):
        super().__init__()
        self.num_steps = num_steps

        # Embedding tables: 10 entries for the "q" feature, 9 each for
        # "row" and "col".
        self.digit_embed = nn.Embedding(10, embed_size)
        self.row_embed = nn.Embedding(9, embed_size)
        self.col_embed = nn.Embedding(9, embed_size)

        # Maps the concatenated (digit, row, col) embeddings to hidden_dim.
        self.input_layer = self._make_mlp(3 * embed_size, hidden_dim)

        # Input is the concatenation of x and the message m (see
        # node_update_func), hence 2 * hidden_dim.
        self.lstm = nn.LSTMCell(hidden_dim * 2, hidden_dim, bias=False)

        # The message MLP is owned by the RRN module, not registered on self
        # directly.
        msg_layer = self._make_mlp(2 * hidden_dim, hidden_dim)
        self.rrn = RRN(msg_layer, self.node_update_func, num_steps, edge_drop)

        self.output_layer = nn.Linear(hidden_dim, 10)

        self.loss_func = nn.CrossEntropyLoss()

    @staticmethod
    def _make_mlp(in_dim, hidden_dim):
        """Return a 4-layer MLP: Linear/ReLU x3 followed by a final Linear."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

    def forward(self, g, is_training=True):
        """Run the network on graph ``g`` and return predictions and loss.

        Note that this mutates ``g.ndata``: the keys ``"a"``, ``"q"``,
        ``"row"`` and ``"col"`` are popped, and ``"x"``, ``"h"``,
        ``"rnn_h"`` and ``"rnn_c"`` are written.

        Args:
            g: Graph object exposing a dict-like ``ndata`` with the keys
                listed above (presumably a DGL graph -- confirm with caller).
            is_training: Forwarded to ``RRN``; when true the labels are
                repeated ``num_steps`` times before computing the loss.

        Returns:
            A ``(preds, loss)`` tuple: ``preds`` is the argmax over the last
            dimension of the logits, ``loss`` is the cross-entropy between
            the flattened logits and the flattened labels.
        """
        labels = g.ndata.pop("a")

        input_digits = self.digit_embed(g.ndata.pop("q"))
        rows = self.row_embed(g.ndata.pop("row"))
        cols = self.col_embed(g.ndata.pop("col"))

        x = self.input_layer(torch.cat([input_digits, rows, cols], -1))

        # Initial node state: h starts as the input features, the LSTM
        # hidden and cell states start at zero.
        g.ndata["x"] = x
        g.ndata["h"] = x
        g.ndata["rnn_h"] = torch.zeros_like(x, dtype=torch.float)
        g.ndata["rnn_c"] = torch.zeros_like(x, dtype=torch.float)

        outputs = self.rrn(g, is_training)
        logits = self.output_layer(outputs)

        # Predictions are taken before the logits are flattened below.
        preds = torch.argmax(logits, -1)

        if is_training:
            # NOTE(review): assumes RRN returns one output per step when
            # training, so labels are tiled to match -- confirm in rrn.py.
            labels = torch.stack([labels] * self.num_steps, 0)

        logits = logits.view([-1, 10])
        labels = labels.view([-1])
        loss = self.loss_func(logits, labels)
        return preds, loss

    def node_update_func(self, nodes):
        """Update node states with one LSTM cell step.

        Reads ``"x"``, ``"m"``, ``"rnn_h"`` and ``"rnn_c"`` from
        ``nodes.data`` (``"m"`` is presumably the aggregated message written
        by ``RRN`` -- confirm in rrn.py), feeds ``cat([x, m])`` to the LSTM
        cell with ``(rnn_h, rnn_c)`` as the previous state.

        Returns:
            Dict with the new hidden state under both ``"h"`` and
            ``"rnn_h"``, and the new cell state under ``"rnn_c"``.
        """
        x, h, m, c = (
            nodes.data["x"],
            nodes.data["rnn_h"],
            nodes.data["m"],
            nodes.data["rnn_c"],
        )
        new_h, new_c = self.lstm(torch.cat([x, m], -1), (h, c))
        return {"h": new_h, "rnn_c": new_c, "rnn_h": new_h}