# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
from io import StringIO
import unittest
from unittest.mock import MagicMock, patch

import torch

from fairseq import data, checkpoint_utils


def mock_trainer(epoch, num_updates, iterations_in_epoch):
    """Build a ``MagicMock`` that stands in for a trainer.

    The mock's ``load_checkpoint`` returns a dict whose ``'train_iterator'``
    entry carries the given ``epoch`` and ``iterations_in_epoch`` (with
    ``'shuffle'`` fixed to ``False``), and its ``get_num_updates`` returns
    ``num_updates``.

    Args:
        epoch: value stored under ``'epoch'`` in the mocked iterator state.
        num_updates: value returned by the mock's ``get_num_updates``.
        iterations_in_epoch: value stored under ``'iterations_in_epoch'``.

    Returns:
        The configured ``MagicMock``.
    """
    trainer = MagicMock()
    trainer.load_checkpoint.return_value = {
        'train_iterator': {
            'epoch': epoch,
            'iterations_in_epoch': iterations_in_epoch,
            'shuffle': False,
        },
    }
    trainer.get_num_updates.return_value = num_updates
    return trainer


def mock_dict():
    """Return a ``MagicMock`` dictionary whose ``pad``/``eos``/``unk`` yield 1/2/3."""
    dictionary = MagicMock()
    # Special-symbol accessor name -> index the mocked accessor returns.
    special_indices = (('pad', 1), ('eos', 2), ('unk', 3))
    for accessor, index in special_indices:
        getattr(dictionary, accessor).return_value = index
    return dictionary


def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoch):
    """Create a mocked trainer together with an epoch iterator over toy data.

    The data is the token sequence ``0 .. epoch_size - 1`` split into blocks
    of one token each, and the batch sampler puts exactly one index in each
    batch, so the iterator has ``epoch_size`` batches.

    Args:
        epoch: epoch number reported by the mocked checkpoint state.
        epoch_size: number of tokens, and therefore of single-item batches.
        num_updates: value returned by the mocked ``get_num_updates``.
        iterations_in_epoch: position reported by the mocked checkpoint state.

    Returns:
        A ``(trainer, epoch_itr)`` tuple: the ``mock_trainer`` result and a
        ``data.EpochBatchIterator`` over the toy dataset.
    """
    # One row holding the values 0..epoch_size-1.
    tokens = torch.LongTensor(list(range(epoch_size))).view(1, -1)
    tokens_ds = data.TokenBlockDataset(
        tokens, sizes=[tokens.size(-1)], block_size=1, pad=0, eos=1, include_targets=False,
    )
    trainer = mock_trainer(epoch, num_updates, iterations_in_epoch)
    dataset = data.LanguagePairDataset(tokens_ds, tokens_ds.sizes, mock_dict(), shuffle=False)
    epoch_itr = data.EpochBatchIterator(
        dataset=dataset,
        collate_fn=dataset.collater,
        # One sample index per batch.
        batch_sampler=[[i] for i in range(epoch_size)],
    )
    return trainer, epoch_itr


class TestLoadCheckpoint(unittest.TestCase):
    """Check ``checkpoint_utils.load_checkpoint`` against a mocked trainer.

    Every test builds a 150-batch epoch iterator via
    ``get_trainer_and_epoch_itr``, has the mocked trainer hand it back from
    ``get_train_iterator``, calls ``checkpoint_utils.load_checkpoint`` and then
    asserts on the epoch and position reported by the returned iterator.
    """

    def setUp(self):
        """Create the mocked args and start patches for the ``os`` helpers."""
        self.args_mock = MagicMock()
        self.args_mock.optimizer_overrides = '{}'
        self.args_mock.reset_dataloader = False
        self.args_mock.reset_meters = False
        self.args_mock.reset_optimizer = False
        # Patch target -> replacement; kept on self so tests can adjust a
        # replacement's return value (see test_load_no_checkpoint).
        self.patches = {
            'os.makedirs': MagicMock(),
            'os.path.join': MagicMock(),
            'os.path.isfile': MagicMock(return_value=True),
            'os.path.isabs': MagicMock(return_value=False),
        }
        self.applied_patches = [patch(p, d) for p, d in self.patches.items()]
        # Starting a patch is a side effect, so use a plain loop instead of
        # building a throwaway list.
        for applied in self.applied_patches:
            applied.start()

    def test_load_partial_checkpoint(self):
        """Resume from a mocked checkpoint taken 50 batches into epoch 2."""
        with contextlib.redirect_stdout(StringIO()):
            trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50)
            trainer.get_train_iterator = MagicMock(return_value=epoch_itr)

            _, epoch_itr = checkpoint_utils.load_checkpoint(self.args_mock, trainer)

            self.assertEqual(epoch_itr.epoch, 2)
            self.assertEqual(epoch_itr.iterations_in_epoch, 50)

            itr = epoch_itr.next_epoch_itr(shuffle=False)
            self.assertEqual(epoch_itr.epoch, 2)
            self.assertEqual(epoch_itr.iterations_in_epoch, 50)

            self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 50)
            self.assertEqual(epoch_itr.iterations_in_epoch, 51)

            # 98 further batches move the position from 51 to 149, leaving
            # exactly one batch in the epoch.
            for _ in range(150 - 52):
                next(itr)
            self.assertEqual(epoch_itr.iterations_in_epoch, 149)
            self.assertTrue(itr.has_next())
            next(itr)
            self.assertFalse(itr.has_next())

            itr = epoch_itr.next_epoch_itr(shuffle=False)
            self.assertTrue(itr.has_next())
            self.assertEqual(epoch_itr.epoch, 3)
            self.assertEqual(epoch_itr.iterations_in_epoch, 0)

    def test_load_full_checkpoint(self):
        """A mocked checkpoint at the end of epoch 2 (150 of 150) starts epoch 3."""
        with contextlib.redirect_stdout(StringIO()):
            trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 300, 150)
            trainer.get_train_iterator = MagicMock(return_value=epoch_itr)

            _, epoch_itr = checkpoint_utils.load_checkpoint(self.args_mock, trainer)
            itr = epoch_itr.next_epoch_itr(shuffle=False)

            self.assertEqual(epoch_itr.epoch, 3)
            self.assertEqual(epoch_itr.iterations_in_epoch, 0)
            self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 0)

    def test_load_no_checkpoint(self):
        """With ``os.path.isfile`` reporting False, iteration starts at epoch 1."""
        with contextlib.redirect_stdout(StringIO()):
            trainer, epoch_itr = get_trainer_and_epoch_itr(0, 150, 0, 0)
            trainer.get_train_iterator = MagicMock(return_value=epoch_itr)
            # Make the patched os.path.isfile report that no file exists.
            self.patches['os.path.isfile'].return_value = False

            _, epoch_itr = checkpoint_utils.load_checkpoint(self.args_mock, trainer)
            itr = epoch_itr.next_epoch_itr(shuffle=False)

            self.assertEqual(epoch_itr.epoch, 1)
            self.assertEqual(epoch_itr.iterations_in_epoch, 0)
            self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 0)

    def tearDown(self):
        """Stop every patch started in ``setUp``."""
        patch.stopall()


# Run this module's test cases when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()