test_train.py 2.13 KB
Newer Older
alexeib's avatar
alexeib committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import unittest
from unittest.mock import MagicMock, patch

import train


def mock_trainer(epoch, num_updates):
    """Create a fake trainer with a fixed checkpoint epoch and update count.

    Args:
        epoch: value placed under the ``'epoch'`` key of the dict returned
            by ``load_checkpoint(...)``.
        num_updates: value returned by ``get_num_updates()``.

    Returns:
        A ``MagicMock`` standing in for the trainer object.
    """
    canned = {
        'load_checkpoint.return_value': {'epoch': epoch},
        'get_num_updates.return_value': num_updates,
    }
    return MagicMock(**canned)


def mock_loader(length):
    """Create a fake epoch iterator whose ``next()`` yields a sized dataset.

    Args:
        length: value reported by ``len()`` on the yielded dataset mock.

    Returns:
        A ``MagicMock`` where ``next(loader)`` always returns the same
        dataset ``MagicMock`` whose ``__len__`` returns ``length``.
    """
    fake_dataset = MagicMock()
    fake_dataset.configure_mock(**{'__len__.return_value': length})

    fake_iterator = MagicMock()
    fake_iterator.configure_mock(**{'__next__.return_value': fake_dataset})
    return fake_iterator


class TestLoadCheckpoint(unittest.TestCase):
    """Tests for ``train.load_checkpoint`` with filesystem calls patched out.

    ``os.makedirs``, ``os.path.join`` and ``os.path.isfile`` are replaced
    by mocks for the duration of each test, so no real files are touched.
    """

    def setUp(self):
        # Keep the replacement mocks addressable by target name so an
        # individual test can reconfigure one (see test_load_no_checkpoint).
        self.patches = {
            'os.makedirs': MagicMock(),
            'os.path.join': MagicMock(),
            'os.path.isfile': MagicMock(return_value=True),
        }
        self.applied_patches = [patch(p, d) for p, d in self.patches.items()]
        for patcher in self.applied_patches:
            patcher.start()
            # Register the undo immediately: cleanups run even if setUp
            # fails part-way through, unlike tearDown, and only the patches
            # started here are stopped (patch.stopall would stop all).
            self.addCleanup(patcher.stop)

    def test_load_partial_checkpoint(self):
        """Epoch 2 with 200 updates over a 150-item dataset.

        Expects the returned dataset to have 50 items and to be something
        other than the loader's MagicMock (presumably a wrapper that skips
        the already-consumed items -- confirm against train.py).
        """
        trainer = mock_trainer(2, 200)
        loader = mock_loader(150)
        epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader)
        self.assertEqual(epoch, 2)
        self.assertEqual(len(ds), 50)
        self.assertNotIsInstance(ds, MagicMock)

    def test_load_full_checkpoint(self):
        """Epoch 2 with 150 updates over a 150-item dataset.

        Expects the full 150-item dataset back as a MagicMock, i.e. not
        replaced by a different object type.
        """
        trainer = mock_trainer(2, 150)
        loader = mock_loader(150)
        epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader)
        self.assertEqual(epoch, 2)
        self.assertEqual(len(ds), 150)
        self.assertIsInstance(ds, MagicMock)

    def test_load_no_checkpoint(self):
        """No checkpoint file on disk: expects epoch 1 and the full dataset."""
        trainer = mock_trainer(0, 0)
        loader = mock_loader(150)
        # The patched os.path.isfile now reports that no checkpoint exists.
        self.patches['os.path.isfile'].return_value = False

        epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader)
        self.assertEqual(epoch, 1)
        self.assertEqual(len(ds), 150)
        self.assertIsInstance(ds, MagicMock)


if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()