test_runner_lightning_task.py 9.89 KB
Newer Older
facebook-github-bot's avatar
facebook-github-bot committed
1
2
3
4
5
6
7
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved


import os
import unittest
from copy import deepcopy
8
from tempfile import TemporaryDirectory
facebook-github-bot's avatar
facebook-github-bot committed
9
10
11
12
from typing import Dict

import pytorch_lightning as pl  # type: ignore
import torch
13
from d2go.config import CfgNode, temp_defrost
14
from d2go.quantization.modeling import set_backend_and_create_qconfig
15
from d2go.registry.builtin import META_ARCH_REGISTRY
16
from d2go.runner import create_runner
17
from d2go.runner.callbacks.quantization import QuantizationAwareTraining
facebook-github-bot's avatar
facebook-github-bot committed
18
from d2go.runner.lightning_task import GeneralizedRCNNTask
Yanghan Wang's avatar
Yanghan Wang committed
19
from d2go.utils.testing import meta_arch_helper as mah
20
from d2go.utils.testing.helper import tempdir
facebook-github-bot's avatar
facebook-github-bot committed
21
from detectron2.utils.events import EventStorage
22
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
facebook-github-bot's avatar
facebook-github-bot committed
23
from torch import Tensor
24
from torch.ao.quantization.quantize_fx import convert_fx, prepare_qat_fx
facebook-github-bot's avatar
facebook-github-bot committed
25
26
27
28


class TestLightningTask(unittest.TestCase):
    def _get_cfg(self, tmp_dir: str) -> CfgNode:
        """Build a detection test config rooted at ``tmp_dir``.

        EVAL_PERIOD is set to MAX_ITER so evaluation runs exactly once,
        at the end of the (single-step) training loop.
        """
        config = mah.create_detection_cfg(GeneralizedRCNNTask, tmp_dir)
        config.TEST.EVAL_PERIOD = config.SOLVER.MAX_ITER
        return config
    def _get_trainer(self, output_dir: str) -> pl.Trainer:
        """Create a one-step Trainer that saves a ``last.ckpt`` into ``output_dir``."""
        last_ckpt_callback = ModelCheckpoint(dirpath=output_dir, save_last=True)
        return pl.Trainer(
            max_steps=1,
            limit_train_batches=1,
            num_sanity_val_steps=0,
            callbacks=[last_ckpt_callback],
            logger=False,
        )
    def _compare_state_dict(
        self, state1: Dict[str, Tensor], state2: Dict[str, Tensor]
    ) -> bool:
        if state1.keys() != state2.keys():
            return False

        for k in state1:
            if not torch.allclose(state1[k], state2[k]):
                return False
        return True

54
55
56
57
    @tempdir
    def test_load_from_checkpoint(self, tmp_dir) -> None:
        """Model weights saved with ``save_checkpoint`` round-trip through
        ``load_from_checkpoint`` unchanged."""
        task = GeneralizedRCNNTask(self._get_cfg(tmp_dir))
        trainer = self._get_trainer(tmp_dir)

        with EventStorage() as storage:
            task.storage = storage
            trainer.fit(task)

            ckpt_path = os.path.join(tmp_dir, "test.ckpt")
            trainer.save_checkpoint(ckpt_path)
            self.assertTrue(os.path.exists(ckpt_path))

            # Reload the checkpoint and confirm the weights match the trained model.
            reloaded = GeneralizedRCNNTask.load_from_checkpoint(ckpt_path)
            self.assertTrue(
                self._compare_state_dict(
                    task.model.state_dict(), reloaded.model.state_dict()
                )
            )
    @tempdir
    def test_train_ema(self, tmp_dir):
        """After one training step, the EMA state equals a single hand-computed
        EMA update: ``decay * init + (1 - decay) * trained``."""
        cfg = self._get_cfg(tmp_dir)
        cfg.MODEL_EMA.ENABLED = True
        cfg.MODEL_EMA.DECAY = 0.7
        task = GeneralizedRCNNTask(cfg)
        expected = deepcopy(task.model.state_dict())

        trainer = self._get_trainer(tmp_dir)
        with EventStorage() as storage:
            task.storage = storage
            trainer.fit(task)

        # Exactly one optimizer step ran, so apply one EMA update manually.
        for name, trained in task.model.state_dict().items():
            expected[name].copy_(expected[name] * 0.7 + 0.3 * trained)

        self.assertTrue(
            self._compare_state_dict(expected, task.ema_state.state_dict())
        )
    @tempdir
    def test_load_ema_weights(self, tmp_dir) -> None:
        """EMA weights survive a checkpoint round trip and can be applied
        to a model via ``ema_state.apply_to``."""
        cfg = self._get_cfg(tmp_dir)
        cfg.MODEL_EMA.ENABLED = True
        task = GeneralizedRCNNTask(cfg)
        trainer = self._get_trainer(tmp_dir)

        with EventStorage() as storage:
            task.storage = storage
            trainer.fit(task)

        # Reload from the auto-saved "last" checkpoint; EMA state must match.
        reloaded = GeneralizedRCNNTask.load_from_checkpoint(
            os.path.join(tmp_dir, "last.ckpt")
        )
        self.assertTrue(
            self._compare_state_dict(
                task.ema_state.state_dict(), reloaded.ema_state.state_dict()
            )
        )

        # Copying the EMA weights onto the model must make the two agree too.
        reloaded.ema_state.apply_to(reloaded.model)
        self.assertTrue(
            self._compare_state_dict(
                task.ema_state.state_dict(), reloaded.model.state_dict()
            )
        )
    @tempdir
    def test_ema_eval_only_mode(self, tmp_dir: TemporaryDirectory) -> None:
        """Train one model for one iteration, then verify that a second task
        built from config loads the EMA state and applies it to its model."""
        cfg = self._get_cfg(tmp_dir)
        cfg.MODEL.MODELING_HOOKS = ["EMA"]
        cfg.MODEL_EMA.ENABLED = True

        task = GeneralizedRCNNTask(cfg)
        trainer = self._get_trainer(tmp_dir)
        with EventStorage() as storage:
            task.storage = storage
            trainer.fit(task)

        # Second config points at the checkpoint and requests EMA-for-eval.
        eval_cfg = self._get_cfg(tmp_dir)
        eval_cfg.MODEL.MODELING_HOOKS = ["EMA"]
        eval_cfg.MODEL_EMA.ENABLED = True
        eval_cfg.MODEL_EMA.USE_EMA_WEIGHTS_FOR_EVAL_ONLY = True
        eval_cfg.MODEL.WEIGHTS = os.path.join(tmp_dir, "last.ckpt")

        eval_task = GeneralizedRCNNTask.from_config(eval_cfg)

        self.assertTrue(
            eval_task.ema_state, "EMA state is not loaded from checkpoint."
        )
        self.assertTrue(
            len(eval_task.ema_state.state_dict()) > 0,
            "EMA state should not be empty.",
        )
        self.assertTrue(
            self._compare_state_dict(
                task.ema_state.state_dict(), eval_task.model.state_dict()
            ),
            "Task loaded from config should apply the ema_state to the model.",
        )
    def test_create_runner(self):
        """``create_runner`` resolves a fully-qualified class path to the task class."""
        qualified_name = (
            f"{GeneralizedRCNNTask.__module__}.{GeneralizedRCNNTask.__name__}"
        )
        resolved = create_runner(qualified_name)
        self.assertTrue(resolved == GeneralizedRCNNTask)
    @tempdir
    def test_build_model(self, tmp_dir):
        """``build_model`` yields a trainable model by default, and in eval-only
        mode loads either the plain or the EMA weights from a checkpoint."""
        cfg = self._get_cfg(tmp_dir)
        cfg.MODEL_EMA.ENABLED = True
        task = GeneralizedRCNNTask(cfg)
        trainer = self._get_trainer(tmp_dir)

        with EventStorage() as storage:
            task.storage = storage
            trainer.fit(task)

        # A freshly built model starts in training mode.
        model = GeneralizedRCNNTask.build_model(cfg)
        self.assertTrue(model.training)

        ckpt_path = os.path.join(tmp_dir, "last.ckpt")

        # eval_only with regular weights: eval mode + trained (non-EMA) params.
        with temp_defrost(cfg):
            cfg.MODEL.WEIGHTS = ckpt_path
            model = GeneralizedRCNNTask.build_model(cfg, eval_only=True)
            self.assertFalse(model.training)
            self.assertTrue(
                self._compare_state_dict(model.state_dict(), task.model.state_dict())
            )

        # eval_only with the EMA flag: the averaged parameters are swapped in.
        with temp_defrost(cfg):
            cfg.MODEL.WEIGHTS = ckpt_path
            cfg.MODEL_EMA.USE_EMA_WEIGHTS_FOR_EVAL_ONLY = True
            model = GeneralizedRCNNTask.build_model(cfg, eval_only=True)
            self.assertFalse(model.training)
            self.assertTrue(
                self._compare_state_dict(
                    model.state_dict(), task.ema_state.state_dict()
                )
            )
    @tempdir
    def test_qat(self, tmp_dir):
        """QAT training preserves only the attributes whitelisted in the FX
        custom config, and the converted model reloads as a GraphModule."""

        @META_ARCH_REGISTRY.register()
        class QuantizableDetMetaArchForTest(mah.DetMetaArchForTest):
            # Only "preserved_attr" survives prepare/convert; other ad-hoc
            # attributes are dropped by FX graph-mode quantization.
            custom_config_dict = {"preserved_attributes": ["preserved_attr"]}

            def __init__(self, cfg):
                super().__init__(cfg)
                self.avgpool.preserved_attr = "foo"
                self.avgpool.not_preserved_attr = "bar"

            def custom_prepare_fx(self, cfg, is_qat, example_input=None):
                example_inputs = (torch.rand(1, 3, 3, 3),)
                self.avgpool = prepare_qat_fx(
                    self.avgpool,
                    {"": set_backend_and_create_qconfig(cfg, is_train=self.training)},
                    example_inputs,
                    self.custom_config_dict,
                )
                return self

            def custom_convert_fx(self, cfg):
                self.avgpool = convert_fx(
                    self.avgpool, convert_custom_config=self.custom_config_dict
                )
                return self

        cfg = self._get_cfg(tmp_dir)
        cfg.MODEL.META_ARCHITECTURE = "QuantizableDetMetaArchForTest"
        cfg.QUANTIZATION.QAT.ENABLED = True
        task = GeneralizedRCNNTask(cfg)

        qat_trainer = pl.Trainer(
            max_steps=1,
            limit_train_batches=1,
            num_sanity_val_steps=0,
            callbacks=[
                QuantizationAwareTraining.from_config(cfg),
                ModelCheckpoint(dirpath=task.cfg.OUTPUT_DIR, save_last=True),
            ],
            logger=False,
        )
        with EventStorage() as storage:
            task.storage = storage
            qat_trainer.fit(task)

        # The prepared module must keep the whitelisted attribute and drop the rest.
        avgpool_after_prepare = task._prepared.model.avgpool
        self.assertEqual(avgpool_after_prepare.preserved_attr, "foo")
        self.assertFalse(hasattr(avgpool_after_prepare, "not_preserved_attr"))

        # Rebuilding from the checkpoint yields the converted (GraphModule) avgpool.
        with temp_defrost(cfg):
            cfg.MODEL.WEIGHTS = os.path.join(tmp_dir, "last.ckpt")
            model = GeneralizedRCNNTask.build_model(cfg, eval_only=True)
            self.assertTrue(isinstance(model.avgpool, torch.fx.GraphModule))

    @tempdir
    def test_meta_arch_training_step(self, tmp_dir):
        """A meta-arch that defines its own ``training_step`` is driven by the
        task with a batch, an optimizer, and a manual_backward callable."""

        @META_ARCH_REGISTRY.register()
        class DetMetaArchForWithTrainingStep(mah.DetMetaArchForTest):
            def training_step(self, batch, batch_idx, opt, manual_backward):
                assert batch
                assert opt
                assert manual_backward
                # We step the optimizer for progress tracking to occur
                # This is reflected in the Trainer's global_step property
                # which is used to determine when to stop training
                # when specifying the loop bounds with Trainer(max_steps=N)
                opt.step()
                return {"total_loss": 0.4}

        cfg = self._get_cfg(tmp_dir)
        cfg.MODEL.META_ARCHITECTURE = "DetMetaArchForWithTrainingStep"
        task = GeneralizedRCNNTask(cfg)
        trainer = self._get_trainer(tmp_dir)

        with EventStorage() as storage:
            task.storage = storage
            trainer.fit(task)