#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved


import os
import unittest
from copy import deepcopy
from typing import Dict

import pytorch_lightning as pl  # type: ignore
import torch
from d2go.config import CfgNode, temp_defrost
from d2go.runner import create_runner
from d2go.runner.lightning_task import GeneralizedRCNNTask
from d2go.utils.testing import meta_arch_helper as mah
from d2go.utils.testing.helper import tempdir
from detectron2.utils.events import EventStorage
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from torch import Tensor


class TestLightningTask(unittest.TestCase):
    def _get_cfg(self, tmp_dir: str) -> CfgNode:
        cfg = mah.create_detection_cfg(GeneralizedRCNNTask, tmp_dir)
        cfg.TEST.EVAL_PERIOD = cfg.SOLVER.MAX_ITER
        return cfg

    def _compare_state_dict(
        self, state1: Dict[str, Tensor], state2: Dict[str, Tensor]
    ) -> bool:
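        # two state dicts match when they share the same keys and every tensor is allclose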
        if state1.keys() != state2.keys():
            return False

        for k in state1:
            if not torch.allclose(state1[k], state2[k]):
                return False
        return True

    @tempdir
    def test_load_from_checkpoint(self, tmp_dir) -> None:
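        # train for a single step, save a checkpoint, and verify that the
        # reloaded task carries identical model weights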
        task = GeneralizedRCNNTask(self._get_cfg(tmp_dir))

        checkpoint_callback = ModelCheckpoint(dirpath=task.cfg.OUTPUT_DIR)
        params = {
            "max_steps": 1,
            "limit_train_batches": 1,
            "num_sanity_val_steps": 0,
            "checkpoint_callback": checkpoint_callback,
        }
        trainer = pl.Trainer(**params)
        with EventStorage() as storage:
            task.storage = storage
            trainer.fit(task)
            ckpt_path = os.path.join(tmp_dir, "test.ckpt")
            trainer.save_checkpoint(ckpt_path)
            self.assertTrue(os.path.exists(ckpt_path))

            # load model weights from checkpoint
            task2 = GeneralizedRCNNTask.load_from_checkpoint(ckpt_path)
            self.assertTrue(
                self._compare_state_dict(
                    task.model.state_dict(), task2.model.state_dict()
                )
            )

    @tempdir
    def test_train_ema(self, tmp_dir):
        cfg = self._get_cfg(tmp_dir)
        cfg.MODEL_EMA.ENABLED = True
        cfg.MODEL_EMA.DECAY = 0.7
        task = GeneralizedRCNNTask(cfg)
        init_state = deepcopy(task.model.state_dict())

        trainer = pl.Trainer(
            max_steps=1,
            limit_train_batches=1,
            num_sanity_val_steps=0,
        )
        with EventStorage() as storage:
            task.storage = storage
            trainer.fit(task)

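        # manually apply one EMA update with decay 0.7 to the initial weights;
        # after a single training step the tracked EMA state should match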
        for k, v in task.model.state_dict().items():
            init_state[k].copy_(init_state[k] * 0.7 + 0.3 * v)

        self.assertTrue(
            self._compare_state_dict(init_state, task.ema_state.state_dict())
        )

    @tempdir
    def test_load_ema_weights(self, tmp_dir):
        cfg = self._get_cfg(tmp_dir)
        cfg.MODEL_EMA.ENABLED = True
        task = GeneralizedRCNNTask(cfg)
        checkpoint_callback = ModelCheckpoint(
            dirpath=task.cfg.OUTPUT_DIR, save_last=True
        )

        trainer = pl.Trainer(
            max_steps=1,
            limit_train_batches=1,
            num_sanity_val_steps=0,
            callbacks=[checkpoint_callback],
        )

        with EventStorage() as storage:
            task.storage = storage
            trainer.fit(task)

        # load EMA weights from checkpoint
        task2 = GeneralizedRCNNTask.load_from_checkpoint(
            os.path.join(tmp_dir, "last.ckpt")
        )
        self.assertTrue(
            self._compare_state_dict(
                task.ema_state.state_dict(), task2.ema_state.state_dict()
            )
        )

        # apply EMA weights to model
        task2.ema_state.apply_to(task2.model)
        self.assertTrue(
            self._compare_state_dict(
                task.ema_state.state_dict(), task2.model.state_dict()
            )
        )

    def test_create_runner(self):
        task_cls = create_runner(
            f"{GeneralizedRCNNTask.__module__}.{GeneralizedRCNNTask.__name__}"
        )
        self.assertTrue(task_cls == GeneralizedRCNNTask)

    @tempdir
    def test_build_model(self, tmp_dir):
        cfg = self._get_cfg(tmp_dir)
        cfg.MODEL_EMA.ENABLED = True
        task = GeneralizedRCNNTask(cfg)
        checkpoint_callback = ModelCheckpoint(
            dirpath=task.cfg.OUTPUT_DIR, save_last=True
        )

        trainer = pl.Trainer(
            max_steps=1,
            limit_train_batches=1,
            num_sanity_val_steps=0,
            callbacks=[checkpoint_callback],
        )

        with EventStorage() as storage:
            task.storage = storage
            trainer.fit(task)

        # test building untrained model
        model = GeneralizedRCNNTask.build_model(cfg)
        self.assertTrue(model.training)

        # test loading regular weights
        with temp_defrost(cfg):
            cfg.MODEL.WEIGHTS = os.path.join(tmp_dir, "last.ckpt")
            model = GeneralizedRCNNTask.build_model(cfg, eval_only=True)
            self.assertFalse(model.training)
            self.assertTrue(
                self._compare_state_dict(model.state_dict(), task.model.state_dict())
            )

        # test loading EMA weights
        with temp_defrost(cfg):
            cfg.MODEL.WEIGHTS = os.path.join(tmp_dir, "last.ckpt")
            cfg.MODEL_EMA.USE_EMA_WEIGHTS_FOR_EVAL_ONLY = True
            model = GeneralizedRCNNTask.build_model(cfg, eval_only=True)
            self.assertFalse(model.training)
            self.assertTrue(
                self._compare_state_dict(
                    model.state_dict(), task.ema_state.state_dict()
                )
            )