test_train.py 4.55 KB
Newer Older
alexeib's avatar
alexeib committed
1
2
3
4
5
6
7
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

Myle Ott's avatar
Myle Ott committed
8
9
import contextlib
from io import StringIO
alexeib's avatar
alexeib committed
10
11
12
import unittest
from unittest.mock import MagicMock, patch

Myle Ott's avatar
Myle Ott committed
13
14
import torch

15
from fairseq import data, checkpoint_utils
alexeib's avatar
alexeib committed
16
17


Myle Ott's avatar
Myle Ott committed
18
def mock_trainer(epoch, num_updates, iterations_in_epoch):
    """Build a ``MagicMock`` standing in for a trainer with checkpoint state.

    Args:
        epoch: value stored under ``'epoch'`` in the mocked iterator state.
        num_updates: value returned by ``trainer.get_num_updates()``.
        iterations_in_epoch: value stored under ``'iterations_in_epoch'`` in
            the mocked iterator state.

    Returns:
        A ``MagicMock`` whose ``load_checkpoint()`` returns a dict containing
        a ``'train_iterator'`` entry and whose ``get_num_updates()`` returns
        ``num_updates``.
    """
    trainer = MagicMock()
    # Extra state handed back by the (mocked) checkpoint load.
    trainer.load_checkpoint.return_value = {
        'train_iterator': {
            'epoch': epoch,
            'iterations_in_epoch': iterations_in_epoch,
            'shuffle': False,
        },
    }
    trainer.get_num_updates.return_value = num_updates
    return trainer


Myle Ott's avatar
Myle Ott committed
31
32
33
34
35
36
37
38
39
def mock_dict():
    """Return a ``MagicMock`` dictionary with fixed special-symbol indices.

    The mocked ``pad()``, ``eos()`` and ``unk()`` methods return 1, 2 and 3
    respectively.
    """
    dictionary = MagicMock()
    for symbol, index in (('pad', 1), ('eos', 2), ('unk', 3)):
        getattr(dictionary, symbol).return_value = index
    return dictionary


def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoch):
    """Create a mocked trainer and an epoch iterator over a toy dataset.

    The dataset is built from the token ids ``0 .. epoch_size - 1`` split into
    blocks of size 1, and the batch sampler yields one index per batch.

    Args:
        epoch: epoch number forwarded to ``mock_trainer``.
        epoch_size: number of tokens, and number of single-index batches.
        num_updates: update count forwarded to ``mock_trainer``.
        iterations_in_epoch: iteration count forwarded to ``mock_trainer``.

    Returns:
        A ``(trainer, epoch_itr)`` tuple: the mocked trainer and a
        ``data.EpochBatchIterator`` over the toy dataset.
    """
    # A single row holding the ids 0..epoch_size-1.
    tokens = torch.LongTensor(list(range(epoch_size))).view(1, -1)
    tokens_ds = data.TokenBlockDataset(
        tokens, sizes=[tokens.size(-1)], block_size=1, pad=0, eos=1, include_targets=False,
    )
    trainer = mock_trainer(epoch, num_updates, iterations_in_epoch)
    dataset = data.LanguagePairDataset(tokens_ds, tokens_ds.sizes, mock_dict(), shuffle=False)
    epoch_itr = data.EpochBatchIterator(
        dataset=dataset,
        collate_fn=dataset.collater,
        # One sample index per batch.
        batch_sampler=[[i] for i in range(epoch_size)],
    )
    return trainer, epoch_itr
alexeib's avatar
alexeib committed
52
53
54
55
56


class TestLoadCheckpoint(unittest.TestCase):
    """Tests for ``checkpoint_utils.load_checkpoint`` using a mocked trainer.

    Several ``os`` / ``os.path`` functions are patched in ``setUp`` so the
    tests run against mocks instead of the real implementations.
    """

    def setUp(self):
        """Create mocked args and start the ``os`` patches."""
        self.args_mock = MagicMock()
        self.args_mock.optimizer_overrides = '{}'
        self.patches = {
            'os.makedirs': MagicMock(),
            'os.path.join': MagicMock(),
            'os.path.isfile': MagicMock(return_value=True),
            'os.path.isabs': MagicMock(return_value=False),
        }
        self.applied_patches = [patch(p, d) for p, d in self.patches.items()]
        # Plain loop: the patches are started purely for their side effect.
        for applied in self.applied_patches:
            applied.start()

    def test_load_partial_checkpoint(self):
        """Resuming at iteration 50 of 150 in epoch 2 continues from there."""
        with contextlib.redirect_stdout(StringIO()):
            trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50)
            trainer.get_train_iterator = MagicMock(return_value=epoch_itr)

            _, epoch_itr = checkpoint_utils.load_checkpoint(self.args_mock, trainer)

            self.assertEqual(epoch_itr.epoch, 2)
            self.assertEqual(epoch_itr.iterations_in_epoch, 50)

            itr = epoch_itr.next_epoch_itr(shuffle=False)
            self.assertEqual(epoch_itr.epoch, 2)
            self.assertEqual(epoch_itr.iterations_in_epoch, 50)

            self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 50)
            self.assertEqual(epoch_itr.iterations_in_epoch, 51)

            # 51 batches consumed so far; 98 more leaves exactly one batch.
            for _ in range(150 - 52):
                next(itr)
            self.assertEqual(epoch_itr.iterations_in_epoch, 149)
            self.assertTrue(itr.has_next())
            next(itr)
            self.assertFalse(itr.has_next())

            # The following epoch starts from the beginning.
            itr = epoch_itr.next_epoch_itr(shuffle=False)
            self.assertTrue(itr.has_next())
            self.assertEqual(epoch_itr.epoch, 3)
            self.assertEqual(epoch_itr.iterations_in_epoch, 0)

    def test_load_full_checkpoint(self):
        """A checkpoint at iteration 150 of 150 resumes at the start of epoch 3."""
        with contextlib.redirect_stdout(StringIO()):
            trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 300, 150)
            trainer.get_train_iterator = MagicMock(return_value=epoch_itr)

            _, epoch_itr = checkpoint_utils.load_checkpoint(self.args_mock, trainer)
            itr = epoch_itr.next_epoch_itr(shuffle=False)

            self.assertEqual(epoch_itr.epoch, 3)
            self.assertEqual(epoch_itr.iterations_in_epoch, 0)
            self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 0)

    def test_load_no_checkpoint(self):
        """With ``os.path.isfile`` returning False, iteration starts at epoch 1."""
        with contextlib.redirect_stdout(StringIO()):
            trainer, epoch_itr = get_trainer_and_epoch_itr(0, 150, 0, 0)
            trainer.get_train_iterator = MagicMock(return_value=epoch_itr)
            # Make the patched os.path.isfile report that no file exists.
            self.patches['os.path.isfile'].return_value = False

            _, epoch_itr = checkpoint_utils.load_checkpoint(self.args_mock, trainer)
            itr = epoch_itr.next_epoch_itr(shuffle=False)

            self.assertEqual(epoch_itr.epoch, 1)
            self.assertEqual(epoch_itr.iterations_in_epoch, 0)
            self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 0)

    def tearDown(self):
        """Stop every patch started in ``setUp``."""
        patch.stopall()


# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()