test_train.py 2.02 KB
Newer Older
alexeib's avatar
alexeib committed
1
2
3
4
5
6
7
8
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import unittest
9

alexeib's avatar
alexeib committed
10
11
12
13
14
from unittest.mock import MagicMock, patch

import train


alexeib's avatar
alexeib committed
15
def mock_trainer(epoch, num_updates, end_of_epoch):
    """Build a MagicMock standing in for the trainer given to train.load_checkpoint.

    Args:
        epoch: value stored under 'epoch' in the dict returned by
            ``trainer.load_checkpoint(...)``.
        num_updates: value returned by ``trainer.get_num_updates()``.
        end_of_epoch: value stored under 'end_of_epoch' in the dict returned
            by ``trainer.load_checkpoint(...)``.

    Returns:
        A MagicMock with those two methods configured. Every other attribute
        or method is an auto-created MagicMock.
    """
    trainer = MagicMock()
    trainer.load_checkpoint.return_value = {
        'epoch': epoch,
        'end_of_epoch': end_of_epoch,
    }
    trainer.get_num_updates.return_value = num_updates
    return trainer


def mock_loader(length):
    """Build a MagicMock standing in for the training data loader.

    Every ``next(loader)`` call returns the same list ``[0, 1, ..., length - 1]``.
    This presumably stands for one epoch's worth of batches per call inside
    ``train.load_checkpoint``; confirm against train.py.

    Args:
        length: number of items in the list returned by each ``next()`` call.

    Returns:
        A MagicMock whose ``__next__`` is configured as described above.
    """
    loader = MagicMock()
    # MagicMock supports configuring __next__, so the builtin next() works
    # on the mock directly. The same list object is returned on every call.
    loader.__next__.return_value = list(range(length))
    return loader


class TestLoadCheckpoint(unittest.TestCase):

    def setUp(self):
        self.patches = {
            'os.makedirs': MagicMock(),
            'os.path.join': MagicMock(),
            'os.path.isfile': MagicMock(return_value=True),
        }
        self.applied_patches = [patch(p, d) for p, d in self.patches.items()]
        [p.start() for p in self.applied_patches]

    def test_load_partial_checkpoint(self):
alexeib's avatar
alexeib committed
40
        trainer = mock_trainer(2, 200, False)
alexeib's avatar
alexeib committed
41
42
43
        loader = mock_loader(150)
        epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader)
        self.assertEqual(epoch, 2)
44
        self.assertEqual(next(ds), 50)
alexeib's avatar
alexeib committed
45
46

    def test_load_full_checkpoint(self):
alexeib's avatar
alexeib committed
47
        trainer = mock_trainer(2, 300, True)
alexeib's avatar
alexeib committed
48
49
        loader = mock_loader(150)
        epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader)
alexeib's avatar
alexeib committed
50
        self.assertEqual(epoch, 3)
51
        self.assertEqual(next(iter(ds)), 0)
alexeib's avatar
alexeib committed
52
53

    def test_load_no_checkpoint(self):
alexeib's avatar
alexeib committed
54
        trainer = mock_trainer(0, 0, False)
alexeib's avatar
alexeib committed
55
56
57
58
59
        loader = mock_loader(150)
        self.patches['os.path.isfile'].return_value = False

        epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader)
        self.assertEqual(epoch, 1)
60
        self.assertEqual(next(iter(ds)), 0)
alexeib's avatar
alexeib committed
61
62
63
64
65
66
67

    def tearDown(self):
        patch.stopall()


if __name__ == "__main__":
    # Executing this file directly hands control to unittest's command-line
    # runner, which collects the TestCase classes defined in this module.
    unittest.main()