test_lightning_train_net.py 2.13 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import os
import unittest

import numpy as np
from d2go.config import CfgNode
from d2go.config.utils import flatten_config_dict
from d2go.runner.lightning_task import GeneralizedRCNNTask
11
from d2go.tools.lightning_train_net import FINAL_MODEL_CKPT, main
Yanghan Wang's avatar
Yanghan Wang committed
12
from d2go.utils.testing import meta_arch_helper as mah
13
from d2go.utils.testing.helper import enable_ddp_env, tempdir
14
15
16


class TestLightningTrainNet(unittest.TestCase):
    """Tests for the ``d2go.tools.lightning_train_net`` ``main`` entry point."""

    def setUp(self):
        """Patch ``_get_accelerator`` to return ``None`` for every test."""
        # ``import unittest`` alone does not load the ``unittest.mock``
        # submodule; import it explicitly instead of relying on some other
        # module having imported it first.
        from unittest import mock

        # set distributed backend to none to avoid spawning child process,
        # which doesn't inherit the temporary dataset
        patcher = mock.patch(
            "d2go.tools.lightning_train_net._get_accelerator", return_value=None
        )
        # Start first, then register the cleanup, so ``stop`` is only ever
        # scheduled for a patch that is actually active.
        patcher.start()
        self.addCleanup(patcher.stop)

    def _get_cfg(self, tmp_dir) -> CfgNode:
        """Return a detection config for ``GeneralizedRCNNTask`` built for ``tmp_dir``."""
        return mah.create_detection_cfg(GeneralizedRCNNTask, tmp_dir)

    @tempdir
    @enable_ddp_env
    def test_train_net_main(self, root_dir):
        """tests the main training entry point."""
        cfg = self._get_cfg(root_dir)
        main(cfg, root_dir)

    @tempdir
    @enable_ddp_env
    def test_checkpointing(self, tmp_dir):
        """tests saving and loading from checkpoint."""
        cfg = self._get_cfg(tmp_dir)

        out = main(cfg, tmp_dir)

        # The first run must leave both expected checkpoint files in tmp_dir.
        ckpts = [f for f in os.listdir(tmp_dir) if f.endswith(".ckpt")]
        expected_ckpts = ("last.ckpt", FINAL_MODEL_CKPT)
        for ckpt in expected_ckpts:
            self.assertIn(ckpt, ckpts)

        cfg2 = cfg.clone()
        cfg2.defrost()
        # load the last checkpoint from previous training
        cfg2.MODEL.WEIGHTS = os.path.join(tmp_dir, "last.ckpt")

        output_dir = os.path.join(tmp_dir, "output")
        out2 = main(cfg2, output_dir, eval_only=True)

        # Every metric reported by the first run must be reproduced exactly by
        # the eval-only run that starts from the saved checkpoint.
        accuracy = flatten_config_dict(out.accuracy)
        accuracy2 = flatten_config_dict(out2.accuracy)
        for k in accuracy:
            np.testing.assert_equal(accuracy[k], accuracy2[k])