test_train.py 4.01 KB
Newer Older
alexeib's avatar
alexeib committed
1
2
3
4
5
6
7
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

Myle Ott's avatar
Myle Ott committed
8
9
import contextlib
from io import StringIO
alexeib's avatar
alexeib committed
10
11
12
import unittest
from unittest.mock import MagicMock, patch

Myle Ott's avatar
Myle Ott committed
13
14
15
16
import torch

from fairseq import data

alexeib's avatar
alexeib committed
17
18
19
import train


Myle Ott's avatar
Myle Ott committed
20
def mock_trainer(epoch, num_updates, iterations_in_epoch):
    """Build a ``MagicMock`` standing in for a trainer with a saved checkpoint.

    Args:
        epoch: value stored under ``'epoch'`` in the mocked checkpoint state.
        num_updates: value returned by the mock's ``get_num_updates()``.
        iterations_in_epoch: value stored under ``'iterations_in_epoch'`` in
            the mocked checkpoint state.

    Returns:
        A ``MagicMock`` whose ``load_checkpoint(...)`` returns a dict with a
        single ``'train_iterator'`` entry (epoch, iterations_in_epoch and
        ``shuffle=False``) and whose ``get_num_updates()`` returns
        ``num_updates``.
    """
    trainer = MagicMock()
    # The "extra state" handed back by load_checkpoint; only the iterator
    # position is populated here.
    trainer.load_checkpoint.return_value = {
        'train_iterator': {
            'epoch': epoch,
            'iterations_in_epoch': iterations_in_epoch,
            'shuffle': False,
        },
    }
    trainer.get_num_updates.return_value = num_updates
    return trainer


Myle Ott's avatar
Myle Ott committed
33
34
35
36
37
38
39
40
41
def mock_dict():
    """Return a ``MagicMock`` dictionary with fixed special-symbol indices.

    The mock's ``pad()``, ``eos()`` and ``unk()`` methods return 1, 2 and 3
    respectively.
    """
    dictionary = MagicMock()
    special_indices = (('pad', 1), ('eos', 2), ('unk', 3))
    for method_name, index in special_indices:
        getattr(dictionary, method_name).return_value = index
    return dictionary


def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoch):
    """Create a mocked trainer and an epoch iterator over a toy dataset.

    The dataset is built from the token ids ``0 .. epoch_size - 1`` split into
    blocks of size 1, and the batch sampler puts each index in its own batch,
    so the iterator has ``epoch_size`` batches.

    Args:
        epoch: epoch number recorded in the mocked checkpoint state.
        epoch_size: number of tokens (and of single-item batches) in the epoch.
        num_updates: value the mocked trainer reports from ``get_num_updates()``.
        iterations_in_epoch: iteration offset recorded in the mocked
            checkpoint state.

    Returns:
        A ``(trainer, epoch_itr)`` tuple: the ``MagicMock`` trainer from
        ``mock_trainer`` and a ``data.EpochBatchIterator`` over the dataset.
    """
    tokens = torch.LongTensor(list(range(epoch_size))).view(1, -1)
    tokens_ds = data.TokenBlockDataset(
        tokens, sizes=[tokens.size(-1)], block_size=1, pad=0, eos=1, include_targets=False,
    )
    trainer = mock_trainer(epoch, num_updates, iterations_in_epoch)
    dataset = data.LanguagePairDataset(tokens_ds, tokens_ds.sizes, mock_dict(), shuffle=False)
    epoch_itr = data.EpochBatchIterator(
        dataset=dataset,
        collate_fn=dataset.collater,
        # One sample per batch, in index order.
        batch_sampler=[[i] for i in range(epoch_size)],
    )
    return trainer, epoch_itr
alexeib's avatar
alexeib committed
54
55
56
57
58


class TestLoadCheckpoint(unittest.TestCase):
    """Tests for ``train.load_checkpoint`` using a mocked trainer and args."""

    def setUp(self):
        """Create mocked args and replace selected ``os`` functions with mocks."""
        self.args_mock = MagicMock()
        self.args_mock.optimizer_overrides = '{}'
        # NOTE(review): presumably these keep load_checkpoint away from the
        # real filesystem; 'os.path.isfile' controls whether a checkpoint
        # file appears to exist.
        self.patches = {
            'os.makedirs': MagicMock(),
            'os.path.join': MagicMock(),
            'os.path.isfile': MagicMock(return_value=True),
            'os.path.isabs': MagicMock(return_value=False),
        }
        self.applied_patches = [patch(p, d) for p, d in self.patches.items()]
        # Register the cleanup before starting anything: tearDown is skipped
        # when setUp raises, which would otherwise leave 'os' patched.
        self.addCleanup(patch.stopall)
        for applied in self.applied_patches:
            applied.start()

    def test_load_partial_checkpoint(self):
        """A checkpoint saved mid-epoch resumes at the stored iteration."""
        with contextlib.redirect_stdout(StringIO()):
            trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50)

            with patch('train.reload_train', return_value=epoch_itr):
                train.load_checkpoint(self.args_mock, trainer, epoch_itr, 512, None)

            self.assertEqual(epoch_itr.epoch, 2)
            self.assertEqual(epoch_itr.iterations_in_epoch, 50)

            itr = epoch_itr.next_epoch_itr(shuffle=False)
            self.assertEqual(epoch_itr.epoch, 2)
            self.assertEqual(epoch_itr.iterations_in_epoch, 50)

            # The first batch served is the one at offset 50.
            self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 50)
            self.assertEqual(epoch_itr.iterations_in_epoch, 51)

    def test_load_full_checkpoint(self):
        """A checkpoint saved at the end of epoch 2 continues with epoch 3."""
        with contextlib.redirect_stdout(StringIO()):
            # iterations_in_epoch equals epoch_size (150).
            trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 300, 150)

            with patch('train.reload_train', return_value=epoch_itr):
                train.load_checkpoint(self.args_mock, trainer, epoch_itr, 512, None)
            itr = epoch_itr.next_epoch_itr(shuffle=False)

            self.assertEqual(epoch_itr.epoch, 3)
            self.assertEqual(epoch_itr.iterations_in_epoch, 0)
            self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 0)

    def test_load_no_checkpoint(self):
        """With no checkpoint file, iteration starts at epoch 1, offset 0."""
        with contextlib.redirect_stdout(StringIO()):
            trainer, epoch_itr = get_trainer_and_epoch_itr(0, 150, 0, 0)
            # Make the patched os.path.isfile report a missing file.
            self.patches['os.path.isfile'].return_value = False

            train.load_checkpoint(self.args_mock, trainer, epoch_itr, 512, None)
            itr = epoch_itr.next_epoch_itr(shuffle=False)

            self.assertEqual(epoch_itr.epoch, 1)
            self.assertEqual(epoch_itr.iterations_in_epoch, 0)
            self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 0)

    def tearDown(self):
        """Stop every patch started in ``setUp``."""
        patch.stopall()


# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()