# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import contextlib
from io import StringIO
import unittest
from unittest.mock import MagicMock, patch

import torch

from fairseq import data

import train


def mock_trainer(epoch, num_updates, iterations_in_epoch):
    """Build a MagicMock trainer that reports the given training progress.

    Calling ``load_checkpoint()`` on the result returns a dict whose
    ``'train_iterator'`` entry holds ``epoch``, ``iterations_in_epoch`` and
    ``shuffle=False``. Calling ``get_num_updates()`` returns ``num_updates``.
    """
    iterator_state = dict(
        epoch=epoch,
        iterations_in_epoch=iterations_in_epoch,
        shuffle=False,
    )
    fake_trainer = MagicMock()
    fake_trainer.load_checkpoint.return_value = {'train_iterator': iterator_state}
    fake_trainer.get_num_updates.return_value = num_updates
    return fake_trainer


def mock_dict():
    """Return a MagicMock used as the dictionary for LanguagePairDataset.

    Only the special-symbol accessors are configured:
    ``pad()`` -> 1, ``eos()`` -> 2, ``unk()`` -> 3.
    """
    vocab = MagicMock()
    special_symbols = (('pad', 1), ('eos', 2), ('unk', 3))
    for accessor, index in special_symbols:
        getattr(vocab, accessor).return_value = index
    return vocab


def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoch):
    """Create a mocked trainer plus a real EpochBatchIterator for the tests.

    The data is the token sequence ``0 .. epoch_size - 1``. It is wrapped in a
    TokenBlockDataset (third positional argument ``1``, presumably the block
    size — confirm against fairseq), then an unshuffled LanguagePairDataset.
    It is iterated with ``max_tokens=1``.

    Returns:
        tuple: ``(trainer, epoch_itr)``, where ``trainer`` comes from
        ``mock_trainer(epoch, num_updates, iterations_in_epoch)``.
    """
    token_ids = torch.LongTensor(list(range(epoch_size)))
    block_ds = data.TokenBlockDataset(
        token_ids, [len(token_ids)], 1, include_targets=False,
    )
    trainer = mock_trainer(epoch, num_updates, iterations_in_epoch)
    pair_ds = data.LanguagePairDataset(
        block_ds, block_ds.sizes, mock_dict(), shuffle=False,
    )
    epoch_itr = data.EpochBatchIterator(dataset=pair_ds, max_tokens=1)
    return trainer, epoch_itr


class TestLoadCheckpoint(unittest.TestCase):
    """Tests for ``train.load_checkpoint`` restoring an EpochBatchIterator.

    For every test, ``os.makedirs``, ``os.path.join`` and ``os.path.isfile``
    are replaced with mocks. ``os.path.isfile`` reports that a checkpoint
    exists unless a test overrides its ``return_value``.
    """

    def setUp(self):
        self.patches = {
            'os.makedirs': MagicMock(),
            'os.path.join': MagicMock(),
            'os.path.isfile': MagicMock(return_value=True),
        }
        self.applied_patches = [patch(p, d) for p, d in self.patches.items()]
        for patcher in self.applied_patches:
            patcher.start()
            # Cleanups run even when setUp fails partway (tearDown does not),
            # and they only undo this test's own patches, in reverse order.
            self.addCleanup(patcher.stop)

    def test_load_partial_checkpoint(self):
        """A mid-epoch checkpoint resumes the same epoch at the saved offset."""
        with contextlib.redirect_stdout(StringIO()):
            trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50)

            train.load_checkpoint(MagicMock(), trainer, epoch_itr)
            self.assertEqual(epoch_itr.epoch, 2)
            self.assertEqual(epoch_itr.iterations_in_epoch, 50)

            # Requesting the iterator must not reset the restored position.
            itr = epoch_itr.next_epoch_itr(shuffle=False)
            self.assertEqual(epoch_itr.epoch, 2)
            self.assertEqual(epoch_itr.iterations_in_epoch, 50)

            self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 50)
            self.assertEqual(epoch_itr.iterations_in_epoch, 51)

    def test_load_full_checkpoint(self):
        """A checkpoint at the end of an epoch starts the next epoch from 0."""
        with contextlib.redirect_stdout(StringIO()):
            # iterations_in_epoch == epoch_size: presumably one batch per token
            # with max_tokens=1, i.e. epoch 2 was fully consumed.
            trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 300, 150)

            train.load_checkpoint(MagicMock(), trainer, epoch_itr)
            itr = epoch_itr.next_epoch_itr(shuffle=False)

            self.assertEqual(epoch_itr.epoch, 3)
            self.assertEqual(epoch_itr.iterations_in_epoch, 0)
            self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 0)

    def test_load_no_checkpoint(self):
        """Without a checkpoint file, training starts at epoch 1, offset 0."""
        with contextlib.redirect_stdout(StringIO()):
            trainer, epoch_itr = get_trainer_and_epoch_itr(0, 150, 0, 0)
            # Make the patched os.path.isfile report that no checkpoint exists.
            self.patches['os.path.isfile'].return_value = False

            train.load_checkpoint(MagicMock(), trainer, epoch_itr)
            itr = epoch_itr.next_epoch_itr(shuffle=False)

            self.assertEqual(epoch_itr.epoch, 1)
            self.assertEqual(epoch_itr.iterations_in_epoch, 0)
            self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 0)


if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()