test_lightning_train_net.py 2.21 KB
Newer Older
1
2
3
4
5
6
7
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import os
import unittest

import numpy as np
Yanghan Wang's avatar
Yanghan Wang committed
8
import torch.distributed as dist
9
10
11
12
from d2go.config import CfgNode
from d2go.config.utils import flatten_config_dict
from d2go.runner.lightning_task import GeneralizedRCNNTask
from d2go.tools.lightning_train_net import main, FINAL_MODEL_CKPT
Yanghan Wang's avatar
Yanghan Wang committed
13
14
from d2go.utils.testing import meta_arch_helper as mah
from d2go.utils.testing.helper import tempdir
15
16
17


class TestLightningTrainNet(unittest.TestCase):
    """End-to-end tests for the ``d2go.tools.lightning_train_net.main`` entry point."""

    def setUp(self):
        """Patch ``_get_accelerator`` to return ``None`` for the duration of a test."""
        # ``import unittest`` alone does not load the ``unittest.mock``
        # submodule; import it explicitly instead of relying on some other
        # module having imported it first.
        from unittest import mock

        # Set distributed backend to none to avoid spawning child process,
        # which doesn't inherit the temporary dataset.
        patcher = mock.patch(
            "d2go.tools.lightning_train_net._get_accelerator", return_value=None
        )
        patcher.start()
        # Register the undo step only once the patch is actually active.
        self.addCleanup(patcher.stop)

    def _get_cfg(self, tmp_dir) -> CfgNode:
        """Return the detection config built by the test helper for ``tmp_dir``."""
        return mah.create_detection_cfg(GeneralizedRCNNTask, tmp_dir)

    @tempdir
    def test_train_net_main(self, root_dir):
        """Tests the main training entry point."""
        cfg = self._get_cfg(root_dir)
        main(cfg, root_dir)

    @tempdir
    def test_checkpointing(self, tmp_dir):
        """Tests saving and loading from checkpoint."""
        cfg = self._get_cfg(tmp_dir)

        out = main(cfg, tmp_dir)

        # Training is expected to leave both checkpoints in the output dir.
        ckpts = [f for f in os.listdir(tmp_dir) if f.endswith(".ckpt")]
        expected_ckpts = ("last.ckpt", FINAL_MODEL_CKPT)
        for ckpt in expected_ckpts:
            self.assertIn(ckpt, ckpts, f"checkpoint {ckpt!r} was not written")

        cfg2 = cfg.clone()
        cfg2.defrost()
        # Load the last checkpoint from previous training.
        cfg2.MODEL.WEIGHTS = os.path.join(tmp_dir, "last.ckpt")

        output_dir = os.path.join(tmp_dir, "output")
        out2 = main(cfg2, output_dir, eval_only=True)

        # Evaluating the reloaded weights must reproduce every metric
        # reported by the training run.
        accuracy = flatten_config_dict(out.accuracy)
        accuracy2 = flatten_config_dict(out2.accuracy)
        for k in accuracy:
            # Fail with a readable message rather than a bare KeyError.
            self.assertIn(k, accuracy2, f"metric {k!r} missing after reload")
            np.testing.assert_equal(accuracy[k], accuracy2[k])

    def tearDown(self):
        """Destroy the default process group if a test left one initialized."""
        if dist.is_initialized():
            dist.destroy_process_group()