test_lightning_train_net.py 1.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import os
import tempfile
import unittest

import numpy as np
from d2go.config import CfgNode
from d2go.config.utils import flatten_config_dict
from d2go.runner.lightning_task import GeneralizedRCNNTask
from d2go.tools.lightning_train_net import main, FINAL_MODEL_CKPT
Yanghan Wang's avatar
Yanghan Wang committed
13
14
from d2go.utils.testing import meta_arch_helper as mah
from d2go.utils.testing.helper import tempdir
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43


class TestLightningTrainNet(unittest.TestCase):
    """Tests for the Lightning-based ``main`` training entry point."""

    def _get_cfg(self, tmp_dir) -> CfgNode:
        """Build a detection config for ``GeneralizedRCNNTask`` under ``tmp_dir``."""
        return mah.create_detection_cfg(GeneralizedRCNNTask, tmp_dir)

    @tempdir
    def test_train_net_main(self, root_dir):
        """Tests the main training entry point."""
        cfg = self._get_cfg(root_dir)
        # set distributed backend to none to avoid spawning child process,
        # which doesn't inherit the temporary dataset
        main(cfg, accelerator=None)

    @tempdir
    def test_checkpointing(self, tmp_dir):
        """Tests saving and loading from checkpoint.

        Trains once into ``tmp_dir``, checks that exactly ``last.ckpt`` and
        ``FINAL_MODEL_CKPT`` were written there, then re-runs ``main`` in
        eval-only mode from ``last.ckpt`` with a separate output directory and
        asserts every metric from the first run is reproduced exactly.
        """
        cfg = self._get_cfg(tmp_dir)

        out = main(cfg, accelerator=None)
        ckpts = [file for file in os.listdir(tmp_dir) if file.endswith(".ckpt")]
        self.assertCountEqual(["last.ckpt", FINAL_MODEL_CKPT], ckpts)

        # A separate output directory keeps the eval run from writing next to
        # the first run's checkpoints. The context manager guarantees it is
        # removed even when an assertion below fails (the previous manual
        # ``cleanup()`` call was skipped on failure).
        with tempfile.TemporaryDirectory() as tmp_dir2:
            cfg2 = cfg.clone()
            cfg2.defrost()
            cfg2.OUTPUT_DIR = tmp_dir2
            # load the last checkpoint from previous training
            cfg2.MODEL.WEIGHTS = os.path.join(tmp_dir, "last.ckpt")

            out2 = main(cfg2, accelerator=None, eval_only=True)
            accuracy = flatten_config_dict(out.accuracy)
            accuracy2 = flatten_config_dict(out2.accuracy)
            for k in accuracy:
                # err_msg identifies which metric diverged on failure.
                np.testing.assert_equal(
                    accuracy[k], accuracy2[k], err_msg=f"metric {k!r} differs"
                )