test_train.py 3.75 KB
Newer Older
alexeib's avatar
alexeib committed
1
2
3
4
5
6
7
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

Myle Ott's avatar
Myle Ott committed
8
9
import contextlib
from io import StringIO
alexeib's avatar
alexeib committed
10
11
12
import unittest
from unittest.mock import MagicMock, patch

Myle Ott's avatar
Myle Ott committed
13
14
15
16
import torch

from fairseq import data

alexeib's avatar
alexeib committed
17
18
19
import train


Myle Ott's avatar
Myle Ott committed
20
def mock_trainer(epoch, num_updates, iterations_in_epoch):
    """Build a MagicMock trainer whose checkpoint state is pre-seeded.

    Args:
        epoch: value stored under ``['train_iterator']['epoch']``.
        num_updates: value returned by ``trainer.get_num_updates()``.
        iterations_in_epoch: value stored under
            ``['train_iterator']['iterations_in_epoch']``.

    Returns:
        A ``MagicMock`` whose ``load_checkpoint(...)`` returns a dict with a
        ``'train_iterator'`` entry, and whose ``get_num_updates()`` returns
        ``num_updates``.
    """
    trainer = MagicMock()
    # Extra state handed back by load_checkpoint; presumably consumed by
    # train.load_checkpoint to restore the epoch iterator position.
    # 'shuffle' is False so the restored batch order stays predictable.
    trainer.load_checkpoint.return_value = {
        'train_iterator': {
            'epoch': epoch,
            'iterations_in_epoch': iterations_in_epoch,
            'shuffle': False,
        },
    }
    trainer.get_num_updates.return_value = num_updates
    return trainer


Myle Ott's avatar
Myle Ott committed
33
34
35
36
37
38
39
40
41
42
def mock_dict():
    """Return a MagicMock vocabulary exposing fixed special-symbol indices.

    ``pad()`` returns 1, ``eos()`` returns 2 and ``unk()`` returns 3.
    """
    special_ids = (('pad', 1), ('eos', 2), ('unk', 3))
    vocab = MagicMock()
    for method_name, index in special_ids:
        getattr(vocab, method_name).return_value = index
    return vocab


def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoch):
    """Create a mock trainer and a real EpochBatchIterator over a toy dataset.

    The dataset is built from the tokens ``0 .. epoch_size - 1`` with
    ``block_size=1``. Each batch from the sampler holds a single index, so
    batch ``i`` presumably yields token ``i``. The checkpoint tests in this
    module rely on that ordering.

    Args:
        epoch: epoch stored in the mocked checkpoint state.
        epoch_size: number of tokens, and therefore of single-example
            batches, in one epoch.
        num_updates: value the mock trainer reports from get_num_updates().
        iterations_in_epoch: iteration count stored in the mocked checkpoint.

    Returns:
        A ``(trainer, epoch_itr)`` tuple.
    """
    tokens = torch.LongTensor(list(range(epoch_size)))
    # NOTE(review): pad/eos here (0/1) differ from mock_dict() (1/2);
    # presumably harmless with include_targets=False — confirm.
    tokens_ds = data.TokenBlockDataset(
        tokens, sizes=[len(tokens)], block_size=1, pad=0, eos=1, include_targets=False,
    )
    trainer = mock_trainer(epoch, num_updates, iterations_in_epoch)
    dataset = data.LanguagePairDataset(tokens_ds, tokens_ds.sizes, mock_dict(), shuffle=False)
    epoch_itr = data.EpochBatchIterator(
        dataset=dataset,
        collate_fn=dataset.collater,
        # One example per batch, in index order.
        batch_sampler=[[i] for i in range(epoch_size)],
    )
    return trainer, epoch_itr
alexeib's avatar
alexeib committed
52
53
54
55
56


class TestLoadCheckpoint(unittest.TestCase):
    """Tests that train.load_checkpoint restores the epoch iterator position."""

    def setUp(self):
        """Create mocked args and patch filesystem helpers for each test."""
        self.args_mock = MagicMock()
        self.args_mock.optimizer_overrides = '{}'
        # Stub out filesystem helpers (presumably the ones load_checkpoint
        # consults). os.path.isfile reports True so a checkpoint "exists"
        # unless a test overrides it.
        self.patches = {
            'os.makedirs': MagicMock(),
            'os.path.join': MagicMock(),
            'os.path.isfile': MagicMock(return_value=True),
        }
        # Register cleanup before starting any patch. Cleanups also run when
        # setUp itself fails, whereas tearDown does not, so this stops
        # patches from leaking into other tests.
        self.addCleanup(patch.stopall)
        self.applied_patches = [patch(p, d) for p, d in self.patches.items()]
        for p in self.applied_patches:
            p.start()

    def test_load_partial_checkpoint(self):
        """A mid-epoch checkpoint resumes the same epoch at the saved offset."""
        with contextlib.redirect_stdout(StringIO()):
            trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50)

            train.load_checkpoint(self.args_mock, trainer, epoch_itr)
            self.assertEqual(epoch_itr.epoch, 2)
            self.assertEqual(epoch_itr.iterations_in_epoch, 50)

            itr = epoch_itr.next_epoch_itr(shuffle=False)
            self.assertEqual(epoch_itr.epoch, 2)
            self.assertEqual(epoch_itr.iterations_in_epoch, 50)

            # The first batch after resuming is the one at offset 50.
            self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 50)
            self.assertEqual(epoch_itr.iterations_in_epoch, 51)

    def test_load_full_checkpoint(self):
        """An end-of-epoch checkpoint starts the next epoch from the beginning."""
        with contextlib.redirect_stdout(StringIO()):
            trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 300, 150)

            train.load_checkpoint(self.args_mock, trainer, epoch_itr)
            itr = epoch_itr.next_epoch_itr(shuffle=False)

            self.assertEqual(epoch_itr.epoch, 3)
            self.assertEqual(epoch_itr.iterations_in_epoch, 0)
            self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 0)

    def test_load_no_checkpoint(self):
        """Without a checkpoint file, training starts at epoch 1, offset 0."""
        with contextlib.redirect_stdout(StringIO()):
            trainer, epoch_itr = get_trainer_and_epoch_itr(0, 150, 0, 0)
            # Make os.path.isfile report that no checkpoint exists.
            self.patches['os.path.isfile'].return_value = False

            train.load_checkpoint(self.args_mock, trainer, epoch_itr)
            itr = epoch_itr.next_epoch_itr(shuffle=False)

            self.assertEqual(epoch_itr.epoch, 1)
            self.assertEqual(epoch_itr.iterations_in_epoch, 0)
            self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 0)


if __name__ == '__main__':
    # Run this module's tests when the file is executed directly as a script.
    unittest.main()