test_train.py 3.66 KB
Newer Older
alexeib's avatar
alexeib committed
1
2
3
4
5
6
7
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

Myle Ott's avatar
Myle Ott committed
8
9
import contextlib
from io import StringIO
alexeib's avatar
alexeib committed
10
11
12
import unittest
from unittest.mock import MagicMock, patch

Myle Ott's avatar
Myle Ott committed
13
14
15
16
import torch

from fairseq import data

alexeib's avatar
alexeib committed
17
18
19
import train


Myle Ott's avatar
Myle Ott committed
20
def mock_trainer(epoch, num_updates, iterations_in_epoch):
    """Build a ``MagicMock`` trainer describing a saved iterator position.

    ``load_checkpoint()`` on the returned mock yields a dict whose
    ``'train_iterator'`` entry holds the given ``epoch`` and
    ``iterations_in_epoch`` with shuffling disabled. ``get_num_updates()``
    returns ``num_updates``.
    """
    iterator_state = dict(
        epoch=epoch,
        iterations_in_epoch=iterations_in_epoch,
        shuffle=False,
    )
    fake = MagicMock()
    fake.load_checkpoint.return_value = {'train_iterator': iterator_state}
    fake.get_num_updates.return_value = num_updates
    return fake


Myle Ott's avatar
Myle Ott committed
33
34
35
36
37
38
39
40
41
42
43
44
45
46
def mock_dict():
    """Return a ``MagicMock`` dictionary with fixed special-symbol indices.

    ``pad()`` returns 1, ``eos()`` returns 2 and ``unk()`` returns 3.
    """
    fake_dict = MagicMock()
    for method_name, index in (('pad', 1), ('eos', 2), ('unk', 3)):
        getattr(fake_dict, method_name).return_value = index
    return fake_dict


def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoch):
    """Create a mock trainer plus an ``EpochBatchIterator`` over ``epoch_size`` tokens.

    The source tokens are the ids ``0 .. epoch_size - 1``, wrapped in a
    ``TokenBlockDataset`` and an unshuffled ``LanguagePairDataset``. The batch
    sampler puts every index in its own batch, in order. The trainer comes from
    ``mock_trainer`` with the given checkpoint position.

    Returns:
        ``(trainer, epoch_itr)``
    """
    token_ids = torch.LongTensor(list(range(epoch_size)))
    # NOTE(review): the positional ``1`` is presumably the block size -- confirm
    # against the TokenBlockDataset signature.
    blocks = data.TokenBlockDataset(token_ids, [len(token_ids)], 1, include_targets=False)
    trainer = mock_trainer(epoch, num_updates, iterations_in_epoch)
    pair_dataset = data.LanguagePairDataset(blocks, blocks.sizes, mock_dict(), shuffle=False)
    one_item_batches = [[idx] for idx in range(epoch_size)]
    epoch_itr = data.EpochBatchIterator(dataset=pair_dataset, batch_sampler=one_item_batches)
    return trainer, epoch_itr
alexeib's avatar
alexeib committed
50
51
52
53
54


class TestLoadCheckpoint(unittest.TestCase):
    """Tests for how ``train.load_checkpoint`` restores an ``EpochBatchIterator``.

    Each test builds a mock trainer whose ``load_checkpoint`` returns a canned
    ``train_iterator`` state, calls ``train.load_checkpoint``, and then checks
    the epoch and position the iterator resumes from. ``os.makedirs``,
    ``os.path.join`` and ``os.path.isfile`` are patched so no real checkpoint
    file is touched.
    """

    def setUp(self):
        """Create mock args and start the ``os`` patches for a single test."""
        self.args_mock = MagicMock()
        # NOTE(review): presumably parsed by train.load_checkpoint as a dict
        # literal, so '{}' means "no optimizer overrides" -- confirm in train.py.
        self.args_mock.optimizer_overrides = '{}'

        # Patch target -> replacement mock. Tests may tweak these mocks (e.g.
        # isfile's return_value) because the patched name is this exact object.
        self.patches = {
            'os.makedirs': MagicMock(),
            'os.path.join': MagicMock(),
            'os.path.isfile': MagicMock(return_value=True),
        }
        self.applied_patches = [patch(p, d) for p, d in self.patches.items()]
        for patcher in self.applied_patches:
            patcher.start()
            # Undo each patch via addCleanup rather than tearDown. Cleanups run
            # even if setUp fails part-way, whereas tearDown would be skipped
            # and leave earlier patches active for later tests. Stopping only
            # our own patchers also avoids the global patch.stopall().
            self.addCleanup(patcher.stop)

    def test_load_partial_checkpoint(self):
        """A checkpoint saved mid-epoch resumes at the saved position."""
        # Silence anything printed while building/loading.
        with contextlib.redirect_stdout(StringIO()):
            # epoch=2, epoch_size=150, num_updates=200, iterations_in_epoch=50
            trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50)

            train.load_checkpoint(self.args_mock, trainer, epoch_itr)
            self.assertEqual(epoch_itr.epoch, 2)
            self.assertEqual(epoch_itr.iterations_in_epoch, 50)

            # Creating the epoch iterator must not advance the epoch or offset.
            itr = epoch_itr.next_epoch_itr(shuffle=False)
            self.assertEqual(epoch_itr.epoch, 2)
            self.assertEqual(epoch_itr.iterations_in_epoch, 50)

            # Batches hold one token each (ids 0..149), so the first batch after
            # resuming is token 50 and the offset then advances by one.
            self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 50)
            self.assertEqual(epoch_itr.iterations_in_epoch, 51)

    def test_load_full_checkpoint(self):
        """A checkpoint saved at the end of an epoch starts the next epoch."""
        with contextlib.redirect_stdout(StringIO()):
            # iterations_in_epoch == epoch_size (150): the saved epoch is done.
            trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 300, 150)

            train.load_checkpoint(self.args_mock, trainer, epoch_itr)
            itr = epoch_itr.next_epoch_itr(shuffle=False)

            self.assertEqual(epoch_itr.epoch, 3)
            self.assertEqual(epoch_itr.iterations_in_epoch, 0)
            self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 0)

    def test_load_no_checkpoint(self):
        """With no checkpoint file, training starts from epoch 1, offset 0."""
        with contextlib.redirect_stdout(StringIO()):
            trainer, epoch_itr = get_trainer_and_epoch_itr(0, 150, 0, 0)
            # The patched os.path.isfile is this very mock, so changing its
            # return value makes the checkpoint file appear to be missing.
            self.patches['os.path.isfile'].return_value = False

            train.load_checkpoint(self.args_mock, trainer, epoch_itr)
            itr = epoch_itr.next_epoch_itr(shuffle=False)

            self.assertEqual(epoch_itr.epoch, 1)
            self.assertEqual(epoch_itr.iterations_in_epoch, 0)
            self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 0)


if __name__ == "__main__":
    # Allow running this module directly, e.g. ``python test_train.py``.
    unittest.main()