test_runner_lightning_task.py 7.93 KB
Newer Older
facebook-github-bot's avatar
facebook-github-bot committed
1
2
3
4
5
6
7
8
9
10
11
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved


import os
import unittest
from copy import deepcopy
from typing import Dict

import pytorch_lightning as pl  # type: ignore
import torch
12
13
from d2go.config import CfgNode, temp_defrost
from d2go.runner import create_runner
Kai Zhang's avatar
Kai Zhang committed
14
15
16
from d2go.runner.callbacks.quantization import (
    QuantizationAwareTraining,
)
facebook-github-bot's avatar
facebook-github-bot committed
17
from d2go.runner.lightning_task import GeneralizedRCNNTask
Yanghan Wang's avatar
Yanghan Wang committed
18
from d2go.utils.testing import meta_arch_helper as mah
19
from d2go.utils.testing.helper import tempdir
Kai Zhang's avatar
Kai Zhang committed
20
21
from detectron2.modeling import META_ARCH_REGISTRY
from detectron2.modeling.meta_arch.rcnn import GeneralizedRCNN
facebook-github-bot's avatar
facebook-github-bot committed
22
from detectron2.utils.events import EventStorage
23
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
facebook-github-bot's avatar
facebook-github-bot committed
24
from torch import Tensor
Kai Zhang's avatar
Kai Zhang committed
25
from torch.quantization.quantize_fx import prepare_qat_fx, convert_fx
facebook-github-bot's avatar
facebook-github-bot committed
26
27
28
29


class TestLightningTask(unittest.TestCase):
    """Smoke tests for ``GeneralizedRCNNTask`` driven by a Lightning trainer.

    Every training test runs a single optimisation step
    (``max_steps=1, limit_train_batches=1``) with sanity validation disabled,
    then checks checkpointing, EMA bookkeeping, model building or
    quantization-aware training on the result.
    """

    def _get_cfg(self, tmp_dir: str) -> CfgNode:
        """Return the detection config used by all tests in this class.

        Args:
            tmp_dir: Scratch directory handed to ``create_detection_cfg``.
                The tests below read checkpoints back from this directory, so
                it is presumably also the config's ``OUTPUT_DIR`` -- confirm
                against ``meta_arch_helper``.

        Returns:
            The config, with ``TEST.EVAL_PERIOD`` set to ``SOLVER.MAX_ITER``.
        """
        cfg = mah.create_detection_cfg(GeneralizedRCNNTask, tmp_dir)
        # Evaluation period equals the training length.
        cfg.TEST.EVAL_PERIOD = cfg.SOLVER.MAX_ITER
        return cfg

    def _compare_state_dict(
        self, state1: Dict[str, Tensor], state2: Dict[str, Tensor]
    ) -> bool:
        """Return True when both state dicts hold the same keys and values.

        Args:
            state1: First mapping of parameter/buffer name to tensor.
            state2: Second mapping of parameter/buffer name to tensor.

        Returns:
            False if the key sets differ, if any pair of tensors differs in
            shape, or if any pair is not element-wise close according to
            ``torch.allclose`` with its default tolerances; True otherwise.
        """
        if state1.keys() != state2.keys():
            return False

        for name, lhs in state1.items():
            rhs = state2[name]
            # torch.allclose broadcasts its operands: without this guard a
            # (1,) tensor could "equal" a (3,) tensor, and incompatible shapes
            # would raise instead of reporting a mismatch.
            if lhs.shape != rhs.shape:
                return False
            if not torch.allclose(lhs, rhs):
                return False
        return True

    @tempdir
    def test_load_from_checkpoint(self, tmp_dir: str) -> None:
        """A checkpoint saved by the trainer restores identical model weights."""
        task = GeneralizedRCNNTask(self._get_cfg(tmp_dir))

        checkpoint_callback = ModelCheckpoint(dirpath=task.cfg.OUTPUT_DIR)
        trainer = pl.Trainer(
            max_steps=1,
            limit_train_batches=1,
            num_sanity_val_steps=0,
            callbacks=[checkpoint_callback],
        )
        with EventStorage() as storage:
            task.storage = storage
            trainer.fit(task)
            ckpt_path = os.path.join(tmp_dir, "test.ckpt")
            trainer.save_checkpoint(ckpt_path)
            self.assertTrue(os.path.exists(ckpt_path))

            # load model weights from checkpoint
            task2 = GeneralizedRCNNTask.load_from_checkpoint(ckpt_path)
            self.assertTrue(
                self._compare_state_dict(
                    task.model.state_dict(), task2.model.state_dict()
                )
            )

    @tempdir
    def test_train_ema(self, tmp_dir: str) -> None:
        """After one step the EMA state is ``decay * init + (1 - decay) * new``."""
        # Single source of truth for the decay used both by the task and by
        # the expected-value computation below.
        decay = 0.7

        cfg = self._get_cfg(tmp_dir)
        cfg.MODEL_EMA.ENABLED = True
        cfg.MODEL_EMA.DECAY = decay
        task = GeneralizedRCNNTask(cfg)
        # Snapshot the weights before training; deepcopy so that training
        # does not mutate the snapshot.
        init_state = deepcopy(task.model.state_dict())

        trainer = pl.Trainer(
            max_steps=1,
            limit_train_batches=1,
            num_sanity_val_steps=0,
        )
        with EventStorage() as storage:
            task.storage = storage
            trainer.fit(task)

        # Turn the snapshot into the expected EMA state, in place.
        for k, v in task.model.state_dict().items():
            init_state[k].copy_(init_state[k] * decay + (1.0 - decay) * v)

        self.assertTrue(
            self._compare_state_dict(init_state, task.ema_state.state_dict())
        )

    @tempdir
    def test_load_ema_weights(self, tmp_dir: str) -> None:
        """EMA state survives a checkpoint round trip and can be applied."""
        cfg = self._get_cfg(tmp_dir)
        cfg.MODEL_EMA.ENABLED = True
        task = GeneralizedRCNNTask(cfg)
        checkpoint_callback = ModelCheckpoint(
            dirpath=task.cfg.OUTPUT_DIR, save_last=True
        )

        trainer = pl.Trainer(
            max_steps=1,
            limit_train_batches=1,
            num_sanity_val_steps=0,
            callbacks=[checkpoint_callback],
        )

        with EventStorage() as storage:
            task.storage = storage
            trainer.fit(task)

        # load EMA weights from checkpoint
        task2 = GeneralizedRCNNTask.load_from_checkpoint(
            os.path.join(tmp_dir, "last.ckpt")
        )
        self.assertTrue(
            self._compare_state_dict(
                task.ema_state.state_dict(), task2.ema_state.state_dict()
            )
        )

        # apply EMA weights to model
        task2.ema_state.apply_to(task2.model)
        self.assertTrue(
            self._compare_state_dict(
                task.ema_state.state_dict(), task2.model.state_dict()
            )
        )

    def test_create_runner(self) -> None:
        """``create_runner`` resolves the task's dotted path to the class."""
        task_cls = create_runner(
            f"{GeneralizedRCNNTask.__module__}.{GeneralizedRCNNTask.__name__}"
        )
        self.assertIs(task_cls, GeneralizedRCNNTask)

    @tempdir
    def test_build_model(self, tmp_dir: str) -> None:
        """``build_model`` handles untrained, regular and EMA weight loading."""
        cfg = self._get_cfg(tmp_dir)
        cfg.MODEL_EMA.ENABLED = True
        task = GeneralizedRCNNTask(cfg)
        checkpoint_callback = ModelCheckpoint(
            dirpath=task.cfg.OUTPUT_DIR, save_last=True
        )

        trainer = pl.Trainer(
            max_steps=1,
            limit_train_batches=1,
            num_sanity_val_steps=0,
            callbacks=[checkpoint_callback],
        )

        with EventStorage() as storage:
            task.storage = storage
            trainer.fit(task)

        # test building untrained model
        model = GeneralizedRCNNTask.build_model(cfg)
        self.assertTrue(model.training)

        # test loading regular weights
        with temp_defrost(cfg):
            cfg.MODEL.WEIGHTS = os.path.join(tmp_dir, "last.ckpt")
            model = GeneralizedRCNNTask.build_model(cfg, eval_only=True)
            self.assertFalse(model.training)
            self.assertTrue(
                self._compare_state_dict(model.state_dict(), task.model.state_dict())
            )

        # test loading EMA weights
        with temp_defrost(cfg):
            cfg.MODEL.WEIGHTS = os.path.join(tmp_dir, "last.ckpt")
            cfg.MODEL_EMA.USE_EMA_WEIGHTS_FOR_EVAL_ONLY = True
            model = GeneralizedRCNNTask.build_model(cfg, eval_only=True)
            self.assertFalse(model.training)
            self.assertTrue(
                self._compare_state_dict(
                    model.state_dict(), task.ema_state.state_dict()
                )
            )

    @tempdir
    def test_qat(self, tmp_dir: str) -> None:
        """One QAT step keeps preserved attributes and yields an FX module."""

        # NOTE(review): this registers the class in the module-level
        # META_ARCH_REGISTRY each time the test body runs; running the test
        # twice in one process may clash on the name -- confirm whether the
        # registry rejects duplicates.
        @META_ARCH_REGISTRY.register()
        class QuantizableDetMetaArchForTest(mah.DetMetaArchForTest):
            """Test meta-arch whose ``avgpool`` submodule is quantized via FX."""

            custom_config_dict = {"preserved_attributes": ["preserved_attr"]}

            def __init__(self, cfg):
                super().__init__(cfg)
                # Only ``preserved_attr`` is listed in ``custom_config_dict``;
                # the assertions at the end of the test check that it is kept
                # on the prepared module and that the other one is not.
                self.avgpool.preserved_attr = "foo"
                self.avgpool.not_preserved_attr = "bar"

            def prepare_for_quant(self, cfg):
                """Swap ``avgpool`` for its QAT-prepared FX version."""
                self.avgpool = prepare_qat_fx(
                    self.avgpool,
                    {"": torch.quantization.get_default_qat_qconfig()},
                    self.custom_config_dict,
                )
                return self

            def prepare_for_quant_convert(self, cfg):
                """Swap ``avgpool`` for its converted FX version."""
                self.avgpool = convert_fx(
                    self.avgpool, convert_custom_config_dict=self.custom_config_dict
                )
                return self

        cfg = self._get_cfg(tmp_dir)
        cfg.MODEL.META_ARCHITECTURE = "QuantizableDetMetaArchForTest"
        cfg.QUANTIZATION.QAT.ENABLED = True
        task = GeneralizedRCNNTask(cfg)

        callbacks = [
            QuantizationAwareTraining.from_config(cfg),
            ModelCheckpoint(dirpath=task.cfg.OUTPUT_DIR, save_last=True),
        ]
        trainer = pl.Trainer(
            max_steps=1,
            limit_train_batches=1,
            num_sanity_val_steps=0,
            callbacks=callbacks,
            logger=None,
        )
        with EventStorage() as storage:
            task.storage = storage
            trainer.fit(task)
        prepared_avgpool = task._prepared.model.avgpool
        self.assertEqual(prepared_avgpool.preserved_attr, "foo")
        self.assertFalse(hasattr(prepared_avgpool, "not_preserved_attr"))

        with temp_defrost(cfg):
            cfg.MODEL.WEIGHTS = os.path.join(tmp_dir, "last.ckpt")
            model = GeneralizedRCNNTask.build_model(cfg, eval_only=True)
            self.assertIsInstance(model.avgpool, torch.fx.GraphModule)