train_sudoku.py
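"""Train and evaluate a Recurrent Relational Network (SudokuNN) on the Sudoku task.

Example usage (all flags are defined in the argument parser at the bottom of this file):

    python train_sudoku.py --output_dir checkpoints --do_train --do_eval --gpu 0
"""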
from sudoku_data import sudoku_dataloader
import argparse
from sudoku import SudokuNN
import torch
from torch.optim import Adam
import os
import numpy as np


def main(args):
    if args.gpu < 0 or not torch.cuda.is_available():
        device = torch.device('cpu')
    else:
        device = torch.device('cuda', args.gpu)

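    # Build the network: `--steps` sets the number of message-passing steps and
    # `--edge_drop` the dropout rate applied to graph edges (see the argparse help below).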
    model = SudokuNN(num_steps=args.steps, edge_drop=args.edge_drop)

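    # Training: iterate over --epochs epochs, validating on the dev split after each one.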
    if args.do_train:
        if not os.path.exists(args.output_dir):
            os.makedirs(args.output_dir)
        model.to(device)
        train_dataloader = sudoku_dataloader(args.batch_size, segment='train')
        dev_dataloader = sudoku_dataloader(args.batch_size, segment='valid')

        opt = Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

        best_dev_acc = 0.0
        for epoch in range(args.epochs):
            model.train()
            for i, g in enumerate(train_dataloader):
                g = g.to(device)
                _, loss = model(g)
                opt.zero_grad()
                loss.backward()
                opt.step()
                if i % 100 == 0:
                    print(f"Epoch {epoch}, batch {i}, loss {loss.cpu().data}")

            # Validate on the dev split after each epoch
            print("\n=========Dev step========")
            model.eval()
            dev_loss = []
            dev_res = []
            for g in dev_dataloader:
                g = g.to(device)
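                # The node feature 'a' holds the ground-truth digits; reshape to one 81-cell row per puzzle.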
                target = g.ndata['a']
                target = target.view([-1, 81])

                with torch.no_grad():
                    preds, loss = model(g, is_training=False)
                    preds = preds.view([-1, 81])
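                    # A puzzle counts as solved only if all 81 predicted cells match the target.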

                    for i in range(preds.size(0)):
                        dev_res.append(int(torch.equal(preds[i, :], target[i, :])))

                    dev_loss.append(loss.item())

            dev_acc = sum(dev_res) / len(dev_res)
            print(f"Dev loss {np.mean(dev_loss)}, accuracy {dev_acc}")
            if dev_acc >= best_dev_acc:
                torch.save(model.state_dict(), os.path.join(args.output_dir, 'model_best.bin'))
                best_dev_acc = dev_acc
            print(f"Best dev accuracy {best_dev_acc}\n")

        torch.save(model.state_dict(), os.path.join(args.output_dir, 'model_final.bin'))

    if args.do_eval:
        model_path = os.path.join(args.output_dir, 'model_best.bin')
        if not os.path.exists(model_path):
            raise FileNotFoundError("Saved model not found!")

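        # Restore the best checkpoint; map_location allows CPU-only evaluation of a GPU-trained model.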
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.to(device)
        
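        # Evaluate on the held-out test split; same accuracy metric as the dev loop above.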
        test_dataloader = sudoku_dataloader(args.batch_size, segment='test')

        print("\n=========Test step========")
        model.eval()
        test_loss = []
        test_res = []
        for g in test_dataloader:
            g = g.to(device)
            target = g.ndata['a']
            target = target.view([-1, 81])

            with torch.no_grad():
                preds, loss = model(g, is_training=False)
                preds = preds.view([-1, 81])

                for i in range(preds.size(0)):
                    test_res.append(int(torch.equal(preds[i, :], target[i, :])))

                test_loss.append(loss.item())

        test_acc = sum(test_res) / len(test_res)
        print(f"Test loss {np.mean(test_loss)}, accuracy {test_acc}\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Recurrent Relational Network on sudoku task.')
    parser.add_argument("--output_dir", type=str, default=None, required=True,
                        help="The directory to save model")
    parser.add_argument("--do_train", default=False, action="store_true",
                        help="Train the model")
    parser.add_argument("--do_eval", default=False, action="store_true",
                        help="Evaluate the model on test data")
    parser.add_argument("--epochs", type=int, default=100,
                        help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=64,
                        help="Batch size")
    parser.add_argument("--edge_drop", type=float, default=0.4,
                        help="Dropout rate at edges.")
    parser.add_argument("--steps", type=int, default=32,
                        help="Number of message passing steps.")
    parser.add_argument("--gpu", type=int, default=-1,
                        help="gpu")
    parser.add_argument("--lr", type=float, default=2e-4,
                        help="Learning rate")
    parser.add_argument("--weight_decay", type=float, default=1e-4,
                        help="weight decay (L2 penalty)")

    args = parser.parse_args()

    main(args)