"llama/llama.cpp/examples/llava/clip.cpp" did not exist on "de982616f1dde636e46b2cef2edd971b54ef7691"
test_runner_lightning_task.py 8.67 KB
Newer Older
facebook-github-bot's avatar
facebook-github-bot committed
1
2
3
4
5
6
7
8
9
10
11
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
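"""Unit tests for GeneralizedRCNNTask, the D2Go Lightning task: checkpoint
save/load, EMA training and loading, model building, FX graph mode QAT, and
meta-arch training-step overrides."""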


import os
import unittest
from copy import deepcopy
from typing import Dict

import pytorch_lightning as pl  # type: ignore
import torch
from d2go.config import CfgNode, temp_defrost
from d2go.quantization.modeling import set_backend_and_create_qconfig
from d2go.registry.builtin import META_ARCH_REGISTRY
from d2go.runner import create_runner
from d2go.runner.callbacks.quantization import QuantizationAwareTraining
from d2go.runner.lightning_task import GeneralizedRCNNTask
from d2go.utils.testing import meta_arch_helper as mah
from d2go.utils.testing.helper import tempdir
from detectron2.utils.events import EventStorage
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from torch import Tensor
from torch.ao.quantization.quantize_fx import convert_fx, prepare_qat_fx


class TestLightningTask(unittest.TestCase):
    def _get_cfg(self, tmp_dir: str) -> CfgNode:
        cfg = mah.create_detection_cfg(GeneralizedRCNNTask, tmp_dir)
        cfg.TEST.EVAL_PERIOD = cfg.SOLVER.MAX_ITER
        return cfg

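    # minimal single-step Trainer; ModelCheckpoint(save_last=True) guarantees
    # a "last.ckpt" in output_dir for the load tests below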
    def _get_trainer(self, output_dir: str) -> pl.Trainer:
        checkpoint_callback = ModelCheckpoint(dirpath=output_dir, save_last=True)
        return pl.Trainer(
            max_steps=1,
            limit_train_batches=1,
            num_sanity_val_steps=0,
            callbacks=[checkpoint_callback],
            logger=False,
        )

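    # two state dicts match iff they have identical keys and every tensor pair
    # is element-wise close (torch.allclose)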
    def _compare_state_dict(
        self, state1: Dict[str, Tensor], state2: Dict[str, Tensor]
    ) -> bool:
        if state1.keys() != state2.keys():
            return False

        for k in state1:
            if not torch.allclose(state1[k], state2[k]):
                return False
        return True

    @tempdir
    def test_load_from_checkpoint(self, tmp_dir) -> None:
        task = GeneralizedRCNNTask(self._get_cfg(tmp_dir))

        trainer = self._get_trainer(tmp_dir)
        with EventStorage() as storage:
            task.storage = storage
            trainer.fit(task)
            ckpt_path = os.path.join(tmp_dir, "test.ckpt")
            trainer.save_checkpoint(ckpt_path)
            self.assertTrue(os.path.exists(ckpt_path))

            # load model weights from checkpoint
            task2 = GeneralizedRCNNTask.load_from_checkpoint(ckpt_path)
            self.assertTrue(
                self._compare_state_dict(
                    task.model.state_dict(), task2.model.state_dict()
                )
            )

    @tempdir
    def test_train_ema(self, tmp_dir):
        cfg = self._get_cfg(tmp_dir)
        cfg.MODEL_EMA.ENABLED = True
        cfg.MODEL_EMA.DECAY = 0.7
        task = GeneralizedRCNNTask(cfg)
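        # snapshot the initial weights to compute the expected EMA state below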
        init_state = deepcopy(task.model.state_dict())

        trainer = self._get_trainer(tmp_dir)
        with EventStorage() as storage:
            task.storage = storage
            trainer.fit(task)

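        # after one optimizer step with decay 0.7, the EMA state should equal
        # 0.7 * initial_weights + 0.3 * updated_weights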
        for k, v in task.model.state_dict().items():
            init_state[k].copy_(init_state[k] * 0.7 + 0.3 * v)

        self.assertTrue(
            self._compare_state_dict(init_state, task.ema_state.state_dict())
        )

    @tempdir
    def test_load_ema_weights(self, tmp_dir):
        cfg = self._get_cfg(tmp_dir)
        cfg.MODEL_EMA.ENABLED = True
        task = GeneralizedRCNNTask(cfg)
        trainer = self._get_trainer(tmp_dir)
        with EventStorage() as storage:
            task.storage = storage
            trainer.fit(task)

        # load EMA weights from checkpoint
        task2 = GeneralizedRCNNTask.load_from_checkpoint(
            os.path.join(tmp_dir, "last.ckpt")
        )
        self.assertTrue(
            self._compare_state_dict(
                task.ema_state.state_dict(), task2.ema_state.state_dict()
            )
        )

        # apply EMA weights to model
        task2.ema_state.apply_to(task2.model)
        self.assertTrue(
            self._compare_state_dict(
                task.ema_state.state_dict(), task2.model.state_dict()
            )
        )

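    # create_runner resolves a fully-qualified class name back to the class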
    def test_create_runner(self):
        task_cls = create_runner(
            f"{GeneralizedRCNNTask.__module__}.{GeneralizedRCNNTask.__name__}"
        )
        self.assertEqual(task_cls, GeneralizedRCNNTask)

    @tempdir
    def test_build_model(self, tmp_dir):
        cfg = self._get_cfg(tmp_dir)
        cfg.MODEL_EMA.ENABLED = True
        task = GeneralizedRCNNTask(cfg)
        trainer = self._get_trainer(tmp_dir)

        with EventStorage() as storage:
            task.storage = storage
            trainer.fit(task)

        # test building untrained model
        model = GeneralizedRCNNTask.build_model(cfg)
        self.assertTrue(model.training)

        # test loading regular weights
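        # ("last.ckpt" was written by the ModelCheckpoint callback in _get_trainer)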
        with temp_defrost(cfg):
            cfg.MODEL.WEIGHTS = os.path.join(tmp_dir, "last.ckpt")
            model = GeneralizedRCNNTask.build_model(cfg, eval_only=True)
            self.assertFalse(model.training)
            self.assertTrue(
                self._compare_state_dict(model.state_dict(), task.model.state_dict())
            )

        # test loading EMA weights
        with temp_defrost(cfg):
            cfg.MODEL.WEIGHTS = os.path.join(tmp_dir, "last.ckpt")
            cfg.MODEL_EMA.USE_EMA_WEIGHTS_FOR_EVAL_ONLY = True
            model = GeneralizedRCNNTask.build_model(cfg, eval_only=True)
            self.assertFalse(model.training)
            self.assertTrue(
                self._compare_state_dict(
                    model.state_dict(), task.ema_state.state_dict()
                )
            )

    @tempdir
    @unittest.skip(
        "FX Graph Mode Quantization API has been updated, re-enable the test after PyTorch 1.13 stable release"
    )
    def test_qat(self, tmp_dir):
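        # a meta-arch whose avgpool goes through FX graph mode QAT; only
        # attributes listed in "preserved_attributes" should survive
        # prepare/convert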
        @META_ARCH_REGISTRY.register()
        class QuantizableDetMetaArchForTest(mah.DetMetaArchForTest):
            custom_config_dict = {"preserved_attributes": ["preserved_attr"]}

            def __init__(self, cfg):
                super().__init__(cfg)
                self.avgpool.preserved_attr = "foo"
                self.avgpool.not_preserved_attr = "bar"

            def prepare_for_quant(self, cfg):
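                # trace avgpool with example inputs and a backend-specific
                # qconfig for QAT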
                example_inputs = (torch.rand(1, 3, 3, 3),)
                self.avgpool = prepare_qat_fx(
                    self.avgpool,
                    {"": set_backend_and_create_qconfig(cfg, is_train=self.training)},
                    example_inputs,
                    self.custom_config_dict,
                )
                return self

            def prepare_for_quant_convert(self, cfg):
                self.avgpool = convert_fx(
                    self.avgpool, convert_custom_config=self.custom_config_dict
                )
                return self

        cfg = self._get_cfg(tmp_dir)
        cfg.MODEL.META_ARCHITECTURE = "QuantizableDetMetaArchForTest"
        cfg.QUANTIZATION.QAT.ENABLED = True
        task = GeneralizedRCNNTask(cfg)

        callbacks = [
            QuantizationAwareTraining.from_config(cfg),
            ModelCheckpoint(dirpath=task.cfg.OUTPUT_DIR, save_last=True),
        ]
        trainer = pl.Trainer(
            max_steps=1,
            limit_train_batches=1,
            num_sanity_val_steps=0,
            callbacks=callbacks,
            logger=False,
        )
        with EventStorage() as storage:
            task.storage = storage
            trainer.fit(task)
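        # the prepared module keeps "preserved_attr" and drops attributes not
        # listed in custom_config_dict["preserved_attributes"]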
        prepared_avgpool = task._prepared.model.avgpool
        self.assertEqual(prepared_avgpool.preserved_attr, "foo")
        self.assertFalse(hasattr(prepared_avgpool, "not_preserved_attr"))

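        # building from the QAT checkpoint in eval mode should return the
        # converted (quantized) model, so avgpool is a torch.fx.GraphModule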
        with temp_defrost(cfg):
            cfg.MODEL.WEIGHTS = os.path.join(tmp_dir, "last.ckpt")
            model = GeneralizedRCNNTask.build_model(cfg, eval_only=True)
            self.assertTrue(isinstance(model.avgpool, torch.fx.GraphModule))

    @tempdir
    def test_meta_arch_training_step(self, tmp_dir):
        @META_ARCH_REGISTRY.register()
        class DetMetaArchForWithTrainingStep(mah.DetMetaArchForTest):
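            # exercises the path where the meta-arch supplies its own
            # training_step(batch, batch_idx, opt, manual_backward)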
            def training_step(self, batch, batch_idx, opt, manual_backward):
                assert batch
                assert opt
                assert manual_backward
                # Step the optimizer so Lightning's progress tracking advances:
                # the Trainer's global_step property is what determines when to
                # stop training when loop bounds are set via Trainer(max_steps=N)
                opt.step()
                return {"total_loss": 0.4}

        cfg = self._get_cfg(tmp_dir)
        cfg.MODEL.META_ARCHITECTURE = "DetMetaArchForWithTrainingStep"

        task = GeneralizedRCNNTask(cfg)

        trainer = self._get_trainer(tmp_dir)
        with EventStorage() as storage:
            task.storage = storage
            trainer.fit(task)