# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import warnings
from typing import Optional, Sequence

import mmcv
import numpy as np
from mmengine.fileio import get
from mmengine.hooks import Hook
from mmengine.runner import Runner
from mmengine.utils import mkdir_or_exist
from mmengine.visualization import Visualizer

from mmdet3d.registry import HOOKS
from mmdet3d.structures import Det3DDataSample


@HOOKS.register_module()
class Det3DVisualizationHook(Hook):
    """Detection Visualization Hook. Used to visualize validation and testing
    process prediction results.

    In the testing phase:

    1. If ``show`` is True, it means that only the prediction results are
        visualized without storing data, so ``vis_backends`` needs to
        be excluded.
    2. If ``test_out_dir`` is specified, it means that the prediction results
        need to be saved to ``test_out_dir``. In order to avoid vis_backends
        also storing data, ``vis_backends`` needs to be excluded.
    3. ``vis_backends`` takes effect if the user does not specify ``show``
        and ``test_out_dir``. You can set ``vis_backends`` to WandbVisBackend
        or TensorboardVisBackend to store the prediction result in Wandb or
        Tensorboard.

    Args:
        draw (bool): Whether to draw prediction results. If it is False,
            no drawing will be done. Defaults to False.
        interval (int): The interval (in validation iterations) of
            visualization. Defaults to 50.
        score_thr (float): The threshold to visualize the bboxes
            and masks. Defaults to 0.3.
        show (bool): Whether to display the drawn image. If True, the
            visualizer's vis backends are cleared. Defaults to False.
        vis_task (str): Visualization task. Defaults to 'mono_det'.
        wait_time (float): The interval of show (s). Defaults to 0.
        test_out_dir (str, optional): Directory where painted images
            will be saved in testing process. It is joined onto
            ``{runner.work_dir}/{runner.timestamp}`` the first time it is
            needed (an absolute path is therefore used as-is).
            Defaults to None.
        backend_args (dict, optional): Arguments to instantiate the
            corresponding file backend used to read images and point
            clouds. Defaults to None.
    """

    def __init__(self,
                 draw: bool = False,
                 interval: int = 50,
                 score_thr: float = 0.3,
                 show: bool = False,
                 vis_task: str = 'mono_det',
                 wait_time: float = 0.,
                 test_out_dir: Optional[str] = None,
                 backend_args: Optional[dict] = None):
        self._visualizer: Visualizer = Visualizer.get_current_instance()
        self.interval = interval
        self.score_thr = score_thr
        self.show = show
        if self.show:
            # No need to think about vis backends.
            self._visualizer._vis_backends = {}
            warnings.warn('The show is True, it means that only '
                          'the prediction results are visualized '
                          'without storing data, so vis_backends '
                          'needs to be excluded.')
        self.vis_task = vis_task
        self.wait_time = wait_time
        self.backend_args = backend_args
        self.draw = draw
        self.test_out_dir = test_out_dir
        # Running counter of test samples, used as the visualizer ``step``.
        self._test_index = 0
        # Set once ``test_out_dir`` has been joined onto the runner's work
        # dir and created, so it is not re-joined on every test iteration.
        self._test_out_dir_resolved = False

    def _load_data_input(self, data_sample: Det3DDataSample) -> dict:
        """Load the raw image and/or point cloud referenced by a sample.

        Args:
            data_sample (:obj:`Det3DDataSample`): Sample whose ``img_path``
                and/or ``lidar_path`` (together with ``num_pts_feats``)
                fields are read, when present.

        Returns:
            dict: May contain ``'img'`` (image decoded by
            :func:`mmcv.imfrombytes` in RGB order) and/or ``'points'``
            (float32 array reshaped to ``(-1, num_pts_feats)``). Empty when
            neither path field is present.
        """
        data_input = dict()
        if 'img_path' in data_sample:
            img_bytes = get(
                data_sample.img_path, backend_args=self.backend_args)
            data_input['img'] = mmcv.imfrombytes(
                img_bytes, channel_order='rgb')
        if 'lidar_path' in data_sample:
            num_pts_feats = data_sample.num_pts_feats
            pts_bytes = get(
                data_sample.lidar_path, backend_args=self.backend_args)
            # Points are decoded as a flat float32 buffer, then split into
            # rows of ``num_pts_feats`` values each.
            points = np.frombuffer(pts_bytes, dtype=np.float32)
            data_input['points'] = points.reshape(-1, num_pts_feats)
        return data_input

    def _get_out_file(self, data_sample: Det3DDataSample) -> Optional[str]:
        """Build the output file path for a test sample.

        The file name comes from the sample's ``img_path`` when present.
        Otherwise it comes from ``lidar_path`` with its extension replaced
        by ``.png``. Returns None when the sample has neither field.

        Args:
            data_sample (:obj:`Det3DDataSample`): The test sample.

        Returns:
            str, optional: Path inside ``self.test_out_dir``, or None.
        """
        if 'img_path' in data_sample:
            file_name = osp.basename(data_sample.img_path)
        elif 'lidar_path' in data_sample:
            # NOTE(review): assumes the visualizer renders an image into
            # ``out_file`` for point clouds too, hence the ``.png`` suffix.
            stem = osp.splitext(osp.basename(data_sample.lidar_path))[0]
            file_name = stem + '.png'
        else:
            return None
        return osp.join(self.test_out_dir, file_name)

    def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
                       outputs: Sequence[Det3DDataSample]) -> None:
        """Run after every ``self.interval`` validation iterations.

        Args:
            runner (:obj:`Runner`): The runner of the validation process.
            batch_idx (int): The index of the current batch in the val loop.
            data_batch (dict): Data from dataloader.
            outputs (Sequence[:obj:`Det3DDataSample`]): A batch of data
                samples that contain annotations and predictions.
        """
        if not self.draw:
            return

        # There is no guarantee that the same batch of images
        # is visualized for each evaluation.
        total_curr_iter = runner.iter + batch_idx
        # Check the interval before touching any files, so skipped
        # iterations do not pay for reading images / point clouds.
        if total_curr_iter % self.interval != 0:
            return

        # Visualize only the first data sample of the batch.
        data_sample = outputs[0]
        self._visualizer.add_datasample(
            'val sample',
            self._load_data_input(data_sample),
            data_sample=data_sample,
            show=self.show,
            vis_task=self.vis_task,
            wait_time=self.wait_time,
            pred_score_thr=self.score_thr,
            step=total_curr_iter)

    def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
                        outputs: Sequence[Det3DDataSample]) -> None:
        """Run after every testing iteration.

        Args:
            runner (:obj:`Runner`): The runner of the testing process.
            batch_idx (int): The index of the current batch in the test loop.
            data_batch (dict): Data from dataloader.
            outputs (Sequence[:obj:`Det3DDataSample`]): A batch of data
                samples that contain annotations and predictions.
        """
        if not self.draw:
            return

        # Resolve and create the output directory only once; re-joining it
        # every iteration would nest the path when ``work_dir`` is relative.
        if self.test_out_dir is not None and not self._test_out_dir_resolved:
            self.test_out_dir = osp.join(runner.work_dir, runner.timestamp,
                                         self.test_out_dir)
            mkdir_or_exist(self.test_out_dir)
            self._test_out_dir_resolved = True

        for data_sample in outputs:
            self._test_index += 1

            data_input = self._load_data_input(data_sample)

            # Derived per sample so a sample without ``img_path`` never
            # crashes or reuses another sample's file name.
            out_file = None
            if self.test_out_dir is not None:
                out_file = self._get_out_file(data_sample)

            self._visualizer.add_datasample(
                'test sample',
                data_input,
                data_sample=data_sample,
                show=self.show,
                vis_task=self.vis_task,
                wait_time=self.wait_time,
                pred_score_thr=self.score_thr,
                out_file=out_file,
                step=self._test_index)