# visualization_hook.py
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import warnings
from typing import Optional, Sequence

import mmcv
import numpy as np
from mmengine.fileio import get
from mmengine.hooks import Hook
from mmengine.runner import Runner
from mmengine.utils import mkdir_or_exist
from mmengine.visualization import Visualizer

from mmdet3d.registry import HOOKS
from mmdet3d.structures import Det3DDataSample


@HOOKS.register_module()
class Det3DVisualizationHook(Hook):
    """Detection Visualization Hook. Used to visualize validation and testing
    process prediction results.

    In the testing phase:

    1. If ``show`` is True, it means that only the prediction results are
        visualized without storing data, so ``vis_backends`` needs to
        be excluded.
    2. If ``test_out_dir`` is specified, it means that the prediction results
        need to be saved to ``test_out_dir``. In order to avoid vis_backends
        also storing data, so ``vis_backends`` needs to be excluded.
    3. ``vis_backends`` takes effect if the user does not specify ``show``
        and `test_out_dir``. You can set ``vis_backends`` to WandbVisBackend or
        TensorboardVisBackend to store the prediction result in Wandb or
        Tensorboard.

    Args:
        draw (bool): whether to draw prediction results. If it is False,
            it means that no drawing will be done. Defaults to False.
        interval (int): The interval of visualization. Defaults to 50.
        score_thr (float): The threshold to visualize the bboxes
            and masks. Defaults to 0.3.
        show (bool): Whether to display the drawn image. Default to False.
43
        vis_task (str): Visualization task. Defaults to 'mono_det'.
44
45
46
        wait_time (float): The interval of show (s). Defaults to 0.
        test_out_dir (str, optional): directory where painted images
            will be saved in testing process.
47
48
        backend_args (dict, optional): Arguments to instantiate the
            corresponding backend. Defaults to None.
49
50
51
52
53
54
55
    """

    def __init__(self,
                 draw: bool = False,
                 interval: int = 50,
                 score_thr: float = 0.3,
                 show: bool = False,
                 vis_task: str = 'mono_det',
                 wait_time: float = 0.,
                 test_out_dir: Optional[str] = None,
                 backend_args: Optional[dict] = None):
        self._visualizer: Visualizer = Visualizer.get_current_instance()
        self.interval = interval
        self.score_thr = score_thr
        self.show = show
        if self.show:
            # No need to think about vis backends.
            self._visualizer._vis_backends = {}
            warnings.warn('`show` is True, which means that only the '
                          'prediction results are visualized without '
                          'storing data, so vis_backends has been '
                          'excluded.')
        self.vis_task = vis_task

        self.wait_time = wait_time
        self.backend_args = backend_args
        self.draw = draw
        self.test_out_dir = test_out_dir
        self._test_index = 0

    def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
                       outputs: Sequence[Det3DDataSample]) -> None:
        """Run after every ``self.interval`` validation iterations.

        Args:
            runner (:obj:`Runner`): The runner of the validation process.
            batch_idx (int): The index of the current batch in the val loop.
            data_batch (dict): Data from dataloader.
            outputs (Sequence[:obj:`Det3DDataSample`]): A batch of data
                samples that contain annotations and predictions.
        """
        if self.draw is False:
            return

        # There is no guarantee that the same batch of images
        # is visualized for each evaluation.
        total_curr_iter = runner.iter + batch_idx

        data_input = dict()

        # Visualize only the first data sample in the batch
        if self.vis_task in [
                'mono_det', 'multi-view_det', 'multi-modality_det'
        ]:
            assert 'img_path' in outputs[0], 'img_path is not in outputs[0]'
            img_path = outputs[0].img_path
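            # ``img_path`` may be a single path (monocular) or a list of
            # per-view paths (multi-view); load each view as an RGB image.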
            if isinstance(img_path, list):
                img = []
                for single_img_path in img_path:
                    img_bytes = get(
                        single_img_path, backend_args=self.backend_args)
                    single_img = mmcv.imfrombytes(
                        img_bytes, channel_order='rgb')
                    img.append(single_img)
            else:
                img_bytes = get(img_path, backend_args=self.backend_args)
                img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
            data_input['img'] = img

        if self.vis_task in ['lidar_det', 'multi-modality_det', 'lidar_seg']:
            assert 'lidar_path' in outputs[0], \
                'lidar_path is not in outputs[0]'
            lidar_path = outputs[0].lidar_path
            num_pts_feats = outputs[0].num_pts_feats
            pts_bytes = get(lidar_path, backend_args=self.backend_args)
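            # Decode the raw bytes into an (N, num_pts_feats) float32 array
            # so the visualizer can render the point cloud.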
            points = np.frombuffer(pts_bytes, dtype=np.float32)
            points = points.reshape(-1, num_pts_feats)
            data_input['points'] = points

        if total_curr_iter % self.interval == 0:
            self._visualizer.add_datasample(
                'val sample',
                data_input,
                data_sample=outputs[0],
                show=self.show,
                vis_task=self.vis_task,
                wait_time=self.wait_time,
                pred_score_thr=self.score_thr,
                step=total_curr_iter)

    def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
                        outputs: Sequence[Det3DDataSample]) -> None:
        """Run after every testing iterations.

        Args:
            runner (:obj:`Runner`): The runner of the testing process.
            batch_idx (int): The index of the current batch in the test loop.
            data_batch (dict): Data from dataloader.
            outputs (Sequence[:obj:`Det3DDataSample`]): A batch of data samples
                that contain annotations and predictions.
        """
        if self.draw is False:
            return

        if self.test_out_dir is not None:
            self.test_out_dir = osp.join(runner.work_dir, runner.timestamp,
                                         self.test_out_dir)
            mkdir_or_exist(self.test_out_dir)

        for data_sample in outputs:
            self._test_index += 1

            data_input = dict()
            assert 'img_path' in data_sample or 'lidar_path' in data_sample, \
                "'data_sample' must contain 'img_path' or 'lidar_path'"

            out_file = o3d_save_path = None

            if self.vis_task in [
                    'mono_det', 'multi-view_det', 'multi-modality_det'
            ]:
                assert 'img_path' in data_sample, \
                    'img_path is not in data_sample'
                img_path = data_sample.img_path
                if isinstance(img_path, list):
                    img = []
                    for single_img_path in img_path:
                        img_bytes = get(
                            single_img_path, backend_args=self.backend_args)
                        single_img = mmcv.imfrombytes(
                            img_bytes, channel_order='rgb')
                        img.append(single_img)
                else:
                    img_bytes = get(img_path, backend_args=self.backend_args)
                    img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
                data_input['img'] = img
                if self.test_out_dir is not None:
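                    # If multiple views are present, name the output file
                    # after the first view's image path.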
                    if isinstance(img_path, list):
                        img_path = img_path[0]
                    out_file = osp.basename(img_path)
                    out_file = osp.join(self.test_out_dir, out_file)

            if self.vis_task in [
                    'lidar_det', 'multi-modality_det', 'lidar_seg'
            ]:
                assert 'lidar_path' in data_sample, \
                    'lidar_path is not in data_sample'
                lidar_path = data_sample.lidar_path
                num_pts_feats = data_sample.num_pts_feats
                pts_bytes = get(lidar_path, backend_args=self.backend_args)
                points = np.frombuffer(pts_bytes, dtype=np.float32)
                points = points.reshape(-1, num_pts_feats)
                data_input['points'] = points
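                # Derive a PNG path (named after the LiDAR file) for the
                # rendered point cloud when results are saved to
                # ``test_out_dir``.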
                if self.test_out_dir is not None:
                    o3d_save_path = osp.basename(lidar_path).split(
                        '.')[0] + '.png'
                    o3d_save_path = osp.join(self.test_out_dir, o3d_save_path)

            self._visualizer.add_datasample(
                'test sample',
                data_input,
                data_sample=data_sample,
                show=self.show,
                vis_task=self.vis_task,
                wait_time=self.wait_time,
                pred_score_thr=self.score_thr,
                out_file=out_file,
                o3d_save_path=o3d_save_path,
                step=self._test_index)