"tests/vscode:/vscode.git/clone" did not exist on "ccf860f1db38b839db9dcde206b6b5091ac50385"
test_lightning_train_net.py 2.08 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import os
import unittest

import numpy as np
from d2go.config import CfgNode
from d2go.config.utils import flatten_config_dict
from d2go.runner.lightning_task import GeneralizedRCNNTask
from d2go.tools.lightning_train_net import main, FINAL_MODEL_CKPT
Yanghan Wang's avatar
Yanghan Wang committed
12
13
from d2go.utils.testing import meta_arch_helper as mah
from d2go.utils.testing.helper import tempdir
14
15
16


class TestLightningTrainNet(unittest.TestCase):
    """Tests for the ``d2go.tools.lightning_train_net`` training entry point."""

    def setUp(self):
        """Patch ``get_accelerator`` to return ``None`` for the whole test."""
        # ``import unittest`` does not load the ``unittest.mock`` submodule;
        # import it explicitly so this does not depend on some other module
        # having imported it first.
        from unittest import mock

        # set distributed backend to none to avoid spawning child process,
        # which doesn't inherit the temporary dataset
        patcher = mock.patch(
            "d2go.tools.lightning_train_net.get_accelerator", return_value=None
        )
        # Register the cleanup before starting so the patch is always undone.
        self.addCleanup(patcher.stop)
        patcher.start()

    def _get_cfg(self, tmp_dir) -> CfgNode:
        """Return the detection config built by the meta-arch test helper.

        Args:
            tmp_dir: directory handed to ``create_detection_cfg``
                (presumably used for the dataset/output -- confirm in helper).
        """
        return mah.create_detection_cfg(GeneralizedRCNNTask, tmp_dir)

    @tempdir
    def test_train_net_main(self, root_dir):
        """Test that the main training entry point runs end to end."""
        cfg = self._get_cfg(root_dir)
        # The accelerator is already patched out in setUp, so this runs
        # in-process.
        main(cfg)

    @tempdir
    def test_checkpointing(self, tmp_dir):
        """Test saving checkpoints and evaluating from a saved checkpoint."""
        cfg = self._get_cfg(tmp_dir)

        out = main(cfg)

        # Exactly these two checkpoint files are expected directly in tmp_dir.
        ckpts = [file for file in os.listdir(tmp_dir) if file.endswith(".ckpt")]
        self.assertCountEqual(
            [
                "last.ckpt",
                FINAL_MODEL_CKPT,
            ],
            ckpts,
        )

        cfg2 = cfg.clone()
        cfg2.defrost()
        # Send the second run's output to a sub-directory of tmp_dir.
        cfg2.OUTPUT_DIR = os.path.join(tmp_dir, "output")
        # load the last checkpoint from previous training
        cfg2.MODEL.WEIGHTS = os.path.join(tmp_dir, "last.ckpt")

        out2 = main(cfg2, eval_only=True)

        # Every metric from the training run must be reproduced by the
        # evaluation-only run that loaded the checkpoint.
        accuracy = flatten_config_dict(out.accuracy)
        accuracy2 = flatten_config_dict(out2.accuracy)
        for k in accuracy:
            # Report a missing metric as a test failure, not a bare KeyError.
            self.assertIn(k, accuracy2)
            np.testing.assert_equal(accuracy[k], accuracy2[k])