"references/vscode:/vscode.git/clone" did not exist on "6e535db255cee3ce878dd7a54dda01d4ec8932c1"
test_train.py 2.28 KB
Newer Older
alexeib's avatar
alexeib committed
1
2
3
4
5
6
7
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

Myle Ott's avatar
Myle Ott committed
8
9
import contextlib
from io import StringIO
alexeib's avatar
alexeib committed
10
import unittest
11

alexeib's avatar
alexeib committed
12
13
14
15
16
from unittest.mock import MagicMock, patch

import train


alexeib's avatar
alexeib committed
17
def mock_trainer(epoch, num_updates, end_of_epoch):
    """Build a stand-in trainer with fixed checkpoint state.

    Args:
        epoch: value stored under ``'epoch'`` in the dict returned by the
            mocked ``load_checkpoint``.
        num_updates: value returned by the mocked ``get_num_updates``.
        end_of_epoch: value stored under ``'end_of_epoch'`` in the dict
            returned by the mocked ``load_checkpoint``.

    Returns:
        A ``MagicMock`` whose ``load_checkpoint`` and ``get_num_updates``
        are preconfigured; any other attribute is an auto-created mock.
    """
    trainer = MagicMock()
    trainer.load_checkpoint.return_value = {
        'epoch': epoch,
        'end_of_epoch': end_of_epoch,
    }
    trainer.get_num_updates.return_value = num_updates
    return trainer


def mock_loader(length):
    """Build a stand-in data loader that can be advanced with ``next()``.

    Args:
        length: number of items in the list handed back by ``next(loader)``.

    Returns:
        A ``MagicMock`` whose ``__next__`` returns ``list(range(length))``.
        The return value is configured once, so every ``next()`` call
        yields the same list object and the loader never raises
        ``StopIteration``.
    """
    loader = MagicMock()
    loader.__next__.return_value = list(range(length))
    return loader


class TestLoadCheckpoint(unittest.TestCase):
    """Exercise ``train.load_checkpoint`` against a mocked trainer and loader.

    The filesystem calls ``os.makedirs``, ``os.path.join`` and
    ``os.path.isfile`` are replaced with mocks for the duration of each test.
    """

    def setUp(self):
        """Patch the ``os`` helpers; ``os.path.isfile`` reports True by default."""
        self.patches = {
            'os.makedirs': MagicMock(),
            'os.path.join': MagicMock(),
            'os.path.isfile': MagicMock(return_value=True),
        }
        self.applied_patches = [patch(p, d) for p, d in self.patches.items()]
        # Plain loop: starting a patcher is a side effect, not a value to collect.
        for patcher in self.applied_patches:
            patcher.start()

    def test_load_partial_checkpoint(self):
        """Checkpoint flagged as mid-epoch: epoch stays 2, iterator is advanced.

        NOTE(review): the expected first item of 50 presumably comes from
        2 * 150 loader items minus the 200 recorded updates; confirm against
        ``train.load_checkpoint``, which is not visible from this file.
        """
        # Swallow whatever load_checkpoint prints so test output stays clean.
        with contextlib.redirect_stdout(StringIO()):
            trainer = mock_trainer(2, 200, False)
            loader = mock_loader(150)
            epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader)
            self.assertEqual(epoch, 2)
            self.assertEqual(next(ds), 50)

    def test_load_full_checkpoint(self):
        """Checkpoint flagged as end-of-epoch: epoch becomes 3, data starts at 0."""
        with contextlib.redirect_stdout(StringIO()):
            trainer = mock_trainer(2, 300, True)
            loader = mock_loader(150)
            epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader)
            self.assertEqual(epoch, 3)
            self.assertEqual(next(iter(ds)), 0)

    def test_load_no_checkpoint(self):
        """No checkpoint file on disk: epoch is 1 and data starts at 0."""
        with contextlib.redirect_stdout(StringIO()):
            trainer = mock_trainer(0, 0, False)
            loader = mock_loader(150)
            # Make the patched os.path.isfile report that the file is missing.
            self.patches['os.path.isfile'].return_value = False

            epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader)
            self.assertEqual(epoch, 1)
            self.assertEqual(next(iter(ds)), 0)

    def tearDown(self):
        """Undo every patch started in ``setUp``."""
        patch.stopall()


# Run this module's test cases when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()