local_visualizer.py 49 KB
Newer Older
ZCMax's avatar
ZCMax committed
1
2
# Copyright (c) OpenMMLab. All rights reserved.
import copy
3
import math
4
import os
5
import sys
6
7
import time
from typing import List, Optional, Sequence, Tuple, Union
ZCMax's avatar
ZCMax committed
8

9
import matplotlib.pyplot as plt
ZCMax's avatar
ZCMax committed
10
11
import mmcv
import numpy as np
12
13
14
from matplotlib.collections import PatchCollection
from matplotlib.patches import PathPatch
from matplotlib.path import Path
15
from mmdet.visualization import DetLocalVisualizer, get_palette
ZCMax's avatar
ZCMax committed
16
from mmengine.dist import master_only
17
from mmengine.logging import print_log
18
from mmengine.structures import InstanceData
19
from mmengine.visualization import Visualizer as MMENGINE_Visualizer
20
21
from mmengine.visualization.utils import (check_type, color_val_matplotlib,
                                          tensor2ndarray)
ZCMax's avatar
ZCMax committed
22
23
from torch import Tensor

24
from mmdet3d.registry import VISUALIZERS
25
26
from mmdet3d.structures import (BaseInstance3DBoxes, Box3DMode,
                                CameraInstance3DBoxes, Coord3DMode,
27
28
29
                                DepthInstance3DBoxes, DepthPoints,
                                Det3DDataSample, LiDARInstance3DBoxes,
                                PointData, points_cam2img)
30
31
from .vis_utils import (proj_camera_bbox3d_to_img, proj_depth_bbox3d_to_img,
                        proj_lidar_bbox3d_to_img, to_depth_mode)
zhangshilong's avatar
zhangshilong committed
32

ZCMax's avatar
ZCMax committed
33
34
35
try:
    import open3d as o3d
    from open3d import geometry
36
    from open3d.visualization import Visualizer
ZCMax's avatar
ZCMax committed
37
except ImportError:
38
    o3d = geometry = Visualizer = None
ZCMax's avatar
ZCMax committed
39
40
41
42
43
44
45
46
47
48


@VISUALIZERS.register_module()
class Det3DLocalVisualizer(DetLocalVisualizer):
    """MMDetection3D Local Visualizer.

    - 3D detection and segmentation drawing methods

      - draw_bboxes_3d: draw 3D bounding boxes on point clouds
      - draw_proj_bboxes_3d: draw projected 3D bounding boxes on image
zhangshilong's avatar
zhangshilong committed
49
      - draw_seg_mask: draw segmentation mask via per-point colorization
ZCMax's avatar
ZCMax committed
50
51
52

    Args:
        name (str): Name of the instance. Defaults to 'visualizer'.
53
54
55
        points (np.ndarray, optional): Points to visualize with shape (N, 3+C).
            Defaults to None.
        image (np.ndarray, optional): The origin image to draw. The format
ZCMax's avatar
ZCMax committed
56
            should be RGB. Defaults to None.
57
58
59
        pcd_mode (int): The point cloud mode (coordinates): 0 represents LiDAR,
            1 represents CAMERA, 2 represents Depth. Defaults to 0.
        vis_backends (List[dict], optional): Visual backend config list.
ZCMax's avatar
ZCMax committed
60
61
62
63
            Defaults to None.
        save_dir (str, optional): Save file dir for all storage backends.
            If it is None, the backend storage will not save any data.
            Defaults to None.
64
65
66
67
68
69
70
        bbox_color (str or Tuple[int], optional): Color of bbox lines.
            The tuple of color should be in BGR order. Defaults to None.
        text_color (str or Tuple[int]): Color of texts. The tuple of color
            should be in BGR order. Defaults to (200, 200, 200).
        mask_color (str or Tuple[int], optional): Color of masks. The tuple of
            color should be in BGR order. Defaults to None.
        line_width (int or float): The linewidth of lines. Defaults to 3.
71
        frame_cfg (dict): The coordinate frame config while Open3D
zhangshilong's avatar
zhangshilong committed
72
73
            visualization initialization.
            Defaults to dict(size=1, origin=[0, 0, 0]).
74
75
        alpha (int or float): The transparency of bboxes or mask.
            Defaults to 0.8.
76
77
        multi_imgs_col (int): The number of columns in arrangement when showing
            multi-view images.
ZCMax's avatar
ZCMax committed
78
79
80
81

    Examples:
        >>> import numpy as np
        >>> import torch
82
        >>> from mmengine.structures import InstanceData
83
84
        >>> from mmdet3d.structures import (DepthInstance3DBoxes
        ...                                 Det3DDataSample)
zhangshilong's avatar
zhangshilong committed
85
        >>> from mmdet3d.visualization import Det3DLocalVisualizer
ZCMax's avatar
ZCMax committed
86
87

        >>> det3d_local_visualizer = Det3DLocalVisualizer()
88
89
        >>> image = np.random.randint(0, 256, size=(10, 12, 3)).astype('uint8')
        >>> points = np.random.rand(1000, 3)
ZCMax's avatar
ZCMax committed
90
        >>> gt_instances_3d = InstanceData()
91
92
        >>> gt_instances_3d.bboxes_3d = DepthInstance3DBoxes(
        ...     torch.rand((5, 7)))
zhangshilong's avatar
zhangshilong committed
93
        >>> gt_instances_3d.labels_3d = torch.randint(0, 2, (5,))
ZCMax's avatar
ZCMax committed
94
95
        >>> gt_det3d_data_sample = Det3DDataSample()
        >>> gt_det3d_data_sample.gt_instances_3d = gt_instances_3d
zhangshilong's avatar
zhangshilong committed
96
97
        >>> data_input = dict(img=image, points=points)
        >>> det3d_local_visualizer.add_datasample('3D Scene', data_input,
98
99
100
101
102
103
104
105
106
107
108
109
110
        ...                                       gt_det3d_data_sample)

        >>> from mmdet3d.structures import PointData
        >>> det3d_local_visualizer = Det3DLocalVisualizer()
        >>> points = np.random.rand(1000, 3)
        >>> gt_pts_seg = PointData()
        >>> gt_pts_seg.pts_semantic_mask = torch.randint(0, 10, (1000, ))
        >>> gt_det3d_data_sample = Det3DDataSample()
        >>> gt_det3d_data_sample.gt_pts_seg = gt_pts_seg
        >>> data_input = dict(points=points)
        >>> det3d_local_visualizer.add_datasample('3D Scene', data_input,
        ...                                       gt_det3d_data_sample,
        ...                                       vis_task='lidar_seg')
ZCMax's avatar
ZCMax committed
111
112
    """

113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
    def __init__(
        self,
        name: str = 'visualizer',
        points: Optional[np.ndarray] = None,
        image: Optional[np.ndarray] = None,
        pcd_mode: int = 0,
        vis_backends: Optional[List[dict]] = None,
        save_dir: Optional[str] = None,
        bbox_color: Optional[Union[str, Tuple[int]]] = None,
        text_color: Union[str, Tuple[int]] = (200, 200, 200),
        mask_color: Optional[Union[str, Tuple[int]]] = None,
        line_width: Union[int, float] = 3,
        frame_cfg: dict = dict(size=1, origin=[0, 0, 0]),
        alpha: Union[int, float] = 0.8,
        multi_imgs_col: int = 3,
        fig_show_cfg: dict = dict(figsize=(18, 12))
    ) -> None:
ZCMax's avatar
ZCMax committed
130
131
132
133
134
135
136
137
138
139
        super().__init__(
            name=name,
            image=image,
            vis_backends=vis_backends,
            save_dir=save_dir,
            bbox_color=bbox_color,
            text_color=text_color,
            mask_color=mask_color,
            line_width=line_width,
            alpha=alpha)
140
141
        if points is not None:
            self.set_points(points, pcd_mode=pcd_mode, frame_cfg=frame_cfg)
142
143
        self.multi_imgs_col = multi_imgs_col
        self.fig_show_cfg.update(fig_show_cfg)
ZCMax's avatar
ZCMax committed
144

145
146
147
148
        self.flag_pause = False
        self.flag_next = False
        self.flag_exit = False

149
150
151
152
153
154
    def _clear_o3d_vis(self) -> None:
        """Clear open3d vis."""

        if hasattr(self, 'o3d_vis'):
            del self.o3d_vis
            del self.points_colors
155
156
157
            del self.view_control
            if hasattr(self, 'pcd'):
                del self.pcd
158

159
    def _initialize_o3d_vis(self, show=True) -> Visualizer:
160
        """Initialize open3d vis according to frame_cfg.
ZCMax's avatar
ZCMax committed
161
162

        Args:
163
164
            frame_cfg (dict): The config to create coordinate frame in open3d
                vis.
ZCMax's avatar
ZCMax committed
165
166

        Returns:
167
            :obj:`o3d.visualization.Visualizer`: Created open3d vis.
ZCMax's avatar
ZCMax committed
168
        """
169
170
171
        if o3d is None or geometry is None:
            raise ImportError(
                'Please run "pip install open3d" to install open3d first.')
172
173
174
175
176
177
178
179
        glfw_key_escape = 256  # Esc
        glfw_key_space = 32  # Space
        glfw_key_right = 262  # Right
        o3d_vis = o3d.visualization.VisualizerWithKeyCallback()
        o3d_vis.register_key_callback(glfw_key_escape, self.escape_callback)
        o3d_vis.register_key_action_callback(glfw_key_space,
                                             self.space_action_callback)
        o3d_vis.register_key_callback(glfw_key_right, self.right_callback)
180
181
182
        if os.environ.get('DISPLAY', None) is not None and show:
            o3d_vis.create_window()
            self.view_control = o3d_vis.get_view_control()
ZCMax's avatar
ZCMax committed
183
184
185
186
187
        return o3d_vis

    @master_only
    def set_points(self,
                   points: np.ndarray,
188
                   pcd_mode: int = 0,
189
                   vis_mode: str = 'replace',
190
                   frame_cfg: dict = dict(size=1, origin=[0, 0, 0]),
191
                   points_color: Tuple[float] = (0.8, 0.8, 0.8),
ZCMax's avatar
ZCMax committed
192
193
                   points_size: int = 2,
                   mode: str = 'xyz') -> None:
194
        """Set the point cloud to draw.
ZCMax's avatar
ZCMax committed
195
196

        Args:
197
198
199
            points (np.ndarray): Points to visualize with shape (N, 3+C).
            pcd_mode (int): The point cloud mode (coordinates): 0 represents
                LiDAR, 1 represents CAMERA, 2 represents Depth. Defaults to 0.
200
            vis_mode (str): The visualization mode in Open3D:
201
202
203
204
205

                - 'replace': Replace the existing point cloud with input point
                  cloud.
                - 'add': Add input point cloud into existing point cloud.

206
                Defaults to 'replace'.
207
            frame_cfg (dict): The coordinate frame config for Open3D
208
209
                visualization initialization.
                Defaults to dict(size=1, origin=[0, 0, 0]).
210
            points_color (Tuple[float]): The color of points.
211
                Defaults to (1, 1, 1).
212
213
214
215
            points_size (int): The size of points to show on visualizer.
                Defaults to 2.
            mode (str): Indicate type of the input points, available mode
                ['xyz', 'xyzrgb']. Defaults to 'xyz'.
ZCMax's avatar
ZCMax committed
216
217
        """
        assert points is not None
218
        assert vis_mode in ('replace', 'add')
ZCMax's avatar
ZCMax committed
219
220
        check_type('points', points, np.ndarray)

221
        if not hasattr(self, 'o3d_vis'):
222
            self.o3d_vis = self._initialize_o3d_vis()
223

224
225
226
227
        # for now we convert points into depth mode for visualization
        if pcd_mode != Coord3DMode.DEPTH:
            points = Coord3DMode.convert(points, pcd_mode, Coord3DMode.DEPTH)

228
        if hasattr(self, 'pcd') and vis_mode != 'add':
ZCMax's avatar
ZCMax committed
229
230
231
            self.o3d_vis.remove_geometry(self.pcd)

        # set points size in Open3D
232
233
234
235
        render_option = self.o3d_vis.get_render_option()
        if render_option is not None:
            render_option.point_size = points_size
            render_option.background_color = np.asarray([0, 0, 0])
ZCMax's avatar
ZCMax committed
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251

        points = points.copy()
        pcd = geometry.PointCloud()
        if mode == 'xyz':
            pcd.points = o3d.utility.Vector3dVector(points[:, :3])
            points_colors = np.tile(
                np.array(points_color), (points.shape[0], 1))
        elif mode == 'xyzrgb':
            pcd.points = o3d.utility.Vector3dVector(points[:, :3])
            points_colors = points[:, 3:6]
            # normalize to [0, 1] for Open3D drawing
            if not ((points_colors >= 0.0) & (points_colors <= 1.0)).all():
                points_colors /= 255.0
        else:
            raise NotImplementedError

252
253
254
255
        # create coordinate frame
        mesh_frame = geometry.TriangleMesh.create_coordinate_frame(**frame_cfg)
        self.o3d_vis.add_geometry(mesh_frame)

ZCMax's avatar
ZCMax committed
256
257
258
        pcd.colors = o3d.utility.Vector3dVector(points_colors)
        self.o3d_vis.add_geometry(pcd)
        self.pcd = pcd
259
        self.points_colors = points_colors
ZCMax's avatar
ZCMax committed
260
261
262
263
264

    # TODO: assign 3D Box color according to pred / GT labels
    # We draw GT / pred bboxes on the same point cloud scenes
    # for better detection performance comparison
    def draw_bboxes_3d(self,
265
                       bboxes_3d: BaseInstance3DBoxes,
266
267
268
269
270
                       bbox_color: Tuple[float] = (0, 1, 0),
                       points_in_box_color: Tuple[float] = (1, 0, 0),
                       rot_axis: int = 2,
                       center_mode: str = 'lidar_bottom',
                       mode: str = 'xyz') -> None:
ZCMax's avatar
ZCMax committed
271
272
273
274
        """Draw bbox on visualizer and change the color of points inside
        bbox3d.

        Args:
275
276
277
278
279
280
281
282
283
284
285
286
            bboxes_3d (:obj:`BaseInstance3DBoxes`): 3D bbox
                (x, y, z, x_size, y_size, z_size, yaw) to visualize.
            bbox_color (Tuple[float]): The color of 3D bboxes.
                Defaults to (0, 1, 0).
            points_in_box_color (Tuple[float]): The color of points inside 3D
                bboxes. Defaults to (1, 0, 0).
            rot_axis (int): Rotation axis of 3D bboxes. Defaults to 2.
            center_mode (str): Indicates the center of bbox is bottom center or
                gravity center. Available mode
                ['lidar_bottom', 'camera_bottom']. Defaults to 'lidar_bottom'.
            mode (str): Indicates the type of input points, available mode
                ['xyz', 'xyzrgb']. Defaults to 'xyz'.
ZCMax's avatar
ZCMax committed
287
288
289
        """
        # Before visualizing the 3D Boxes in point cloud scene
        # we need to convert the boxes to Depth mode
290
291
292
293
        check_type('bboxes', bboxes_3d, BaseInstance3DBoxes)

        if not isinstance(bboxes_3d, DepthInstance3DBoxes):
            bboxes_3d = bboxes_3d.convert_to(Box3DMode.DEPTH)
ZCMax's avatar
ZCMax committed
294
295
296
297

        # convert bboxes to numpy dtype
        bboxes_3d = tensor2ndarray(bboxes_3d.tensor)

298
        # in_box_color = np.array(points_in_box_color)
ZCMax's avatar
ZCMax committed
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316

        for i in range(len(bboxes_3d)):
            center = bboxes_3d[i, 0:3]
            dim = bboxes_3d[i, 3:6]
            yaw = np.zeros(3)
            yaw[rot_axis] = bboxes_3d[i, 6]
            rot_mat = geometry.get_rotation_matrix_from_xyz(yaw)

            if center_mode == 'lidar_bottom':
                # bottom center to gravity center
                center[rot_axis] += dim[rot_axis] / 2
            elif center_mode == 'camera_bottom':
                # bottom center to gravity center
                center[rot_axis] -= dim[rot_axis] / 2
            box3d = geometry.OrientedBoundingBox(center, rot_mat, dim)

            line_set = geometry.LineSet.create_from_oriented_bounding_box(
                box3d)
317
            line_set.paint_uniform_color(np.array(bbox_color[i]) / 255.)
ZCMax's avatar
ZCMax committed
318
319
320
321
322
323
324
            # draw bboxes on visualizer
            self.o3d_vis.add_geometry(line_set)

            # change the color of points which are in box
            if self.pcd is not None and mode == 'xyz':
                indices = box3d.get_point_indices_within_bounding_box(
                    self.pcd.points)
325
                self.points_colors[indices] = np.array(bbox_color[i]) / 255.
ZCMax's avatar
ZCMax committed
326
327
328
329
330
331

        # update points colors
        if self.pcd is not None:
            self.pcd.colors = o3d.utility.Vector3dVector(self.points_colors)
            self.o3d_vis.update_geometry(self.pcd)

332
333
    def set_bev_image(self,
                      bev_image: Optional[np.ndarray] = None,
334
                      bev_shape: int = 900) -> None:
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
        """Set the bev image to draw.

        Args:
            bev_image (np.ndarray, optional): The bev image to draw.
                Defaults to None.
            bev_shape (int): The bev image shape. Defaults to 900.
        """
        if bev_image is None:
            bev_image = np.zeros((bev_shape, bev_shape, 3), np.uint8)

        self._image = bev_image
        self.width, self.height = bev_image.shape[1], bev_image.shape[0]
        self._default_font_size = max(
            np.sqrt(self.height * self.width) // 90, 10)
        self.ax_save.cla()
        self.ax_save.axis(False)
        self.ax_save.imshow(bev_image, origin='lower')
        # plot camera view range
        x1 = np.linspace(0, self.width / 2)
        x2 = np.linspace(self.width / 2, self.width)
        self.ax_save.plot(
            x1,
            self.width / 2 - x1,
            ls='--',
            color='grey',
            linewidth=1,
            alpha=0.5)
        self.ax_save.plot(
            x2,
            x2 - self.width / 2,
            ls='--',
            color='grey',
            linewidth=1,
            alpha=0.5)
        self.ax_save.plot(
            self.width / 2,
            0,
            marker='+',
            markersize=16,
            markeredgecolor='red')

    # TODO: Support bev point cloud visualization
    @master_only
    def draw_bev_bboxes(self,
                        bboxes_3d: BaseInstance3DBoxes,
                        scale: int = 15,
381
382
                        edge_colors: Union[str, Tuple[int],
                                           List[Union[str, Tuple[int]]]] = 'o',
383
                        line_styles: Union[str, List[str]] = '-',
384
385
386
387
388
389
                        line_widths: Union[int, float, List[Union[int,
                                                                  float]]] = 1,
                        face_colors: Union[str, Tuple[int],
                                           List[Union[str,
                                                      Tuple[int]]]] = 'none',
                        alpha: Union[int, float] = 1) -> MMENGINE_Visualizer:
390
391
392
        """Draw projected 3D boxes on the image.

        Args:
393
394
            bboxes_3d (:obj:`BaseInstance3DBoxes`): 3D bbox
                (x, y, z, x_size, y_size, z_size, yaw) to visualize.
395
396
            scale (dict): Value to scale the bev bboxes for better
                visualization. Defaults to 15.
397
398
            edge_colors (str or Tuple[int] or List[str or Tuple[int]]):
                The colors of bboxes. ``colors`` can have the same length with
399
400
401
402
                lines or just single value. If ``colors`` is single value, all
                the lines will have the same colors. Refer to `matplotlib.
                colors` for full list of formats that are accepted.
                Defaults to 'o'.
403
404
405
406
            line_styles (str or List[str]): The linestyle of lines.
                ``line_styles`` can have the same length with texts or just
                single value. If ``line_styles`` is single value, all the lines
                will have the same linestyle. Reference to
407
408
                https://matplotlib.org/stable/api/collections_api.html?highlight=collection#matplotlib.collections.AsteriskPolygonCollection.set_linestyle
                for more details. Defaults to '-'.
409
410
411
412
413
414
415
            line_widths (int or float or List[int or float]): The linewidth of
                lines. ``line_widths`` can have the same length with lines or
                just single value. If ``line_widths`` is single value, all the
                lines will have the same linewidth. Defaults to 2.
            face_colors (str or Tuple[int] or List[str or Tuple[int]]):
                The face colors. Defaults to 'none'.
            alpha (int or float): The transparency of bboxes. Defaults to 1.
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
        """

        check_type('bboxes', bboxes_3d, BaseInstance3DBoxes)
        bev_bboxes = tensor2ndarray(bboxes_3d.bev)
        # scale the bev bboxes for better visualization
        bev_bboxes[:, :4] *= scale
        ctr, w, h, theta = np.split(bev_bboxes, [2, 3, 4], axis=-1)
        cos_value, sin_value = np.cos(theta), np.sin(theta)
        vec1 = np.concatenate([w / 2 * cos_value, w / 2 * sin_value], axis=-1)
        vec2 = np.concatenate([-h / 2 * sin_value, h / 2 * cos_value], axis=-1)
        pt1 = ctr + vec1 + vec2
        pt2 = ctr + vec1 - vec2
        pt3 = ctr - vec1 - vec2
        pt4 = ctr - vec1 + vec2
        poly = np.stack([pt1, pt2, pt3, pt4], axis=-2)
        # move the object along x-axis
        poly[:, :, 0] += self.width / 2
        poly = [p for p in poly]
        return self.draw_polygons(
            poly,
            alpha=alpha,
            edge_colors=edge_colors,
            line_styles=line_styles,
            line_widths=line_widths,
            face_colors=face_colors)

    @master_only
443
444
445
    def draw_points_on_image(self,
                             points: Union[np.ndarray, Tensor],
                             pts2img: np.ndarray,
446
447
                             sizes: Union[np.ndarray, int] = 3,
                             max_depth: Optional[float] = None) -> None:
448
449
450
        """Draw projected points on the image.

        Args:
451
452
453
454
            points (np.ndarray or Tensor): Points to draw.
            pts2img (np.ndarray): The transformation matrix from the coordinate
                of point cloud to image plane.
            sizes (np.ndarray or int): The marker size. Defaults to 10.
455
456
            max_depth (float): The max depth in the color map. Defaults to
                None.
457
458
459
460
461
462
        """
        check_type('points', points, (np.ndarray, Tensor))
        points = tensor2ndarray(points)
        assert self._image is not None, 'Please set image using `set_image`'
        projected_points = points_cam2img(points, pts2img, with_depth=True)
        depths = projected_points[:, 2]
463
464
465
466
        # Show depth adaptively consideing different scenes
        if max_depth is None:
            max_depth = depths.max()
        colors = (depths % max_depth) / max_depth
467
468
469
470
471
472
473
474
        # use colormap to obtain the render color
        color_map = plt.get_cmap('jet')
        self.ax_save.scatter(
            projected_points[:, 0],
            projected_points[:, 1],
            c=colors,
            cmap=color_map,
            s=sizes,
475
            alpha=0.7,
476
477
            edgecolors='none')

478
    # TODO: set bbox color according to palette
479
    @master_only
480
481
482
483
484
485
486
487
488
489
    def draw_proj_bboxes_3d(
            self,
            bboxes_3d: BaseInstance3DBoxes,
            input_meta: dict,
            edge_colors: Union[str, Tuple[int],
                               List[Union[str, Tuple[int]]]] = 'royalblue',
            line_styles: Union[str, List[str]] = '-',
            line_widths: Union[int, float, List[Union[int, float]]] = 2,
            face_colors: Union[str, Tuple[int],
                               List[Union[str, Tuple[int]]]] = 'royalblue',
490
491
            alpha: Union[int, float] = 0.4,
            img_size: Optional[Tuple] = None):
ZCMax's avatar
ZCMax committed
492
493
494
        """Draw projected 3D boxes on the image.

        Args:
495
496
            bboxes_3d (:obj:`BaseInstance3DBoxes`): 3D bbox
                (x, y, z, x_size, y_size, z_size, yaw) to visualize.
ZCMax's avatar
ZCMax committed
497
            input_meta (dict): Input meta information.
498
499
            edge_colors (str or Tuple[int] or List[str or Tuple[int]]):
                The colors of bboxes. ``colors`` can have the same length with
500
501
502
503
                lines or just single value. If ``colors`` is single value, all
                the lines will have the same colors. Refer to `matplotlib.
                colors` for full list of formats that are accepted.
                Defaults to 'royalblue'.
504
505
506
507
            line_styles (str or List[str]): The linestyle of lines.
                ``line_styles`` can have the same length with texts or just
                single value. If ``line_styles`` is single value, all the lines
                will have the same linestyle. Reference to
508
509
                https://matplotlib.org/stable/api/collections_api.html?highlight=collection#matplotlib.collections.AsteriskPolygonCollection.set_linestyle
                for more details. Defaults to '-'.
510
511
512
513
514
515
516
            line_widths (int or float or List[int or float]): The linewidth of
                lines. ``line_widths`` can have the same length with lines or
                just single value. If ``line_widths`` is single value, all the
                lines will have the same linewidth. Defaults to 2.
            face_colors (str or Tuple[int] or List[str or Tuple[int]]):
                The face colors. Defaults to 'royalblue'.
            alpha (int or float): The transparency of bboxes. Defaults to 0.4.
517
            img_size (tuple, optional): The size (w, h) of the image.
ZCMax's avatar
ZCMax committed
518
519
520
521
        """

        check_type('bboxes', bboxes_3d, BaseInstance3DBoxes)

522
        if isinstance(bboxes_3d, DepthInstance3DBoxes):
ZCMax's avatar
ZCMax committed
523
            proj_bbox3d_to_img = proj_depth_bbox3d_to_img
524
        elif isinstance(bboxes_3d, LiDARInstance3DBoxes):
ZCMax's avatar
ZCMax committed
525
            proj_bbox3d_to_img = proj_lidar_bbox3d_to_img
526
        elif isinstance(bboxes_3d, CameraInstance3DBoxes):
ZCMax's avatar
ZCMax committed
527
528
            proj_bbox3d_to_img = proj_camera_bbox3d_to_img
        else:
529
            raise NotImplementedError('unsupported box type!')
ZCMax's avatar
ZCMax committed
530

531
532
        edge_colors_norm = color_val_matplotlib(edge_colors)

533
        corners_2d = proj_bbox3d_to_img(bboxes_3d, input_meta)
534
535
536
537
538
539
540
541
        if img_size is not None:
            # Filter out the bbox where half of stuff is outside the image.
            # This is for the visualization of multi-view image.
            valid_point_idx = (corners_2d[..., 0] >= 0) & \
                        (corners_2d[..., 0] <= img_size[0]) & \
                        (corners_2d[..., 1] >= 0) & (corners_2d[..., 1] <= img_size[1])  # noqa: E501
            valid_bbox_idx = valid_point_idx.sum(axis=-1) >= 4
            corners_2d = corners_2d[valid_bbox_idx]
542
543
544
545
546
547
548
549
            filter_edge_colors = []
            filter_edge_colors_norm = []
            for i, color in enumerate(edge_colors):
                if valid_bbox_idx[i]:
                    filter_edge_colors.append(color)
                    filter_edge_colors_norm.append(edge_colors_norm[i])
            edge_colors = filter_edge_colors
            edge_colors_norm = filter_edge_colors_norm
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564

        lines_verts_idx = [0, 1, 2, 3, 7, 6, 5, 4, 0, 3, 7, 4, 5, 1, 2, 6]
        lines_verts = corners_2d[:, lines_verts_idx, :]
        front_polys = corners_2d[:, 4:, :]
        codes = [Path.LINETO] * lines_verts.shape[1]
        codes[0] = Path.MOVETO
        pathpatches = []
        for i in range(len(corners_2d)):
            verts = lines_verts[i]
            pth = Path(verts, codes)
            pathpatches.append(PathPatch(pth))

        p = PatchCollection(
            pathpatches,
            facecolors='none',
565
            edgecolors=edge_colors_norm,
566
567
568
569
570
571
572
573
574
575
576
577
578
            linewidths=line_widths,
            linestyles=line_styles)

        self.ax_save.add_collection(p)

        # draw a mask on the front of project bboxes
        front_polys = [front_poly for front_poly in front_polys]
        return self.draw_polygons(
            front_polys,
            alpha=alpha,
            edge_colors=edge_colors,
            line_styles=line_styles,
            line_widths=line_widths,
579
            face_colors=edge_colors)
ZCMax's avatar
ZCMax committed
580

581
    @master_only
582
    def draw_seg_mask(self, seg_mask_colors: np.ndarray) -> None:
ZCMax's avatar
ZCMax committed
583
584
585
        """Add segmentation mask to visualizer via per-point colorization.

        Args:
586
587
588
            seg_mask_colors (np.ndarray): The segmentation mask with shape
                (N, 6), whose first 3 dims are point coordinates and last 3
                dims are converted colors.
ZCMax's avatar
ZCMax committed
589
590
591
592
        """
        # we can't draw the colors on existing points
        # in case gt and pred mask would overlap
        # instead we set a large offset along x-axis for each seg mask
593
594
595
596
597
598
599
600
601
        if hasattr(self, 'pcd'):
            offset = (np.array(self.pcd.points).max(0) -
                      np.array(self.pcd.points).min(0))[0] * 1.2
            mesh_frame = geometry.TriangleMesh.create_coordinate_frame(
                size=1, origin=[offset, 0,
                                0])  # create coordinate frame for seg
            self.o3d_vis.add_geometry(mesh_frame)
        else:
            offset = 0
ZCMax's avatar
ZCMax committed
602
603
        seg_points = copy.deepcopy(seg_mask_colors)
        seg_points[:, 0] += offset
604
        self.set_points(seg_points, pcd_mode=2, vis_mode='add', mode='xyzrgb')
ZCMax's avatar
ZCMax committed
605

606
607
608
609
610
    def _draw_instances_3d(self,
                           data_input: dict,
                           instances: InstanceData,
                           input_meta: dict,
                           vis_task: str,
611
                           show_pcd_rgb: bool = False,
612
                           palette: Optional[List[tuple]] = None) -> dict:
ZCMax's avatar
ZCMax committed
613
614
615
616
        """Draw 3D instances of GT or prediction.

        Args:
            data_input (dict): The input dict to draw.
617
618
619
620
621
            instances (:obj:`InstanceData`): Data structure for instance-level
                annotations or predictions.
            input_meta (dict): Meta information.
            vis_task (str): Visualization task, it includes: 'lidar_det',
                'multi-modality_det', 'mono_det'.
622
            show_pcd_rgb (bool): Whether to show RGB point cloud.
623
624
            palette (List[tuple], optional): Palette information corresponding
                to the category. Defaults to None.
ZCMax's avatar
ZCMax committed
625
626

        Returns:
627
            dict: The drawn point cloud and image whose channel is RGB.
ZCMax's avatar
ZCMax committed
628
629
        """

630
631
632
633
        # Only visualize when there is at least one instance
        if not len(instances) > 0:
            return None

ZCMax's avatar
ZCMax committed
634
        bboxes_3d = instances.bboxes_3d  # BaseInstance3DBoxes
635
        labels_3d = instances.labels_3d
ZCMax's avatar
ZCMax committed
636

637
        data_3d = dict()
ZCMax's avatar
ZCMax committed
638

639
        if vis_task in ['lidar_det', 'multi-modality_det']:
ZCMax's avatar
ZCMax committed
640
641
642
643
644
645
646
647
648
649
            assert 'points' in data_input
            points = data_input['points']
            check_type('points', points, (np.ndarray, Tensor))
            points = tensor2ndarray(points)

            if not isinstance(bboxes_3d, DepthInstance3DBoxes):
                points, bboxes_3d_depth = to_depth_mode(points, bboxes_3d)
            else:
                bboxes_3d_depth = bboxes_3d.clone()

650
651
652
653
654
655
656
657
            if 'axis_align_matrix' in input_meta:
                points = DepthPoints(points, points_dim=points.shape[1])
                rot_mat = input_meta['axis_align_matrix'][:3, :3]
                trans_vec = input_meta['axis_align_matrix'][:3, -1]
                points.rotate(rot_mat.T)
                points.translate(trans_vec)
                points = tensor2ndarray(points.tensor)

658
659
660
661
662
663
            max_label = int(max(labels_3d) if len(labels_3d) > 0 else 0)
            bbox_color = palette if self.bbox_color is None \
                else self.bbox_color
            bbox_palette = get_palette(bbox_color, max_label + 1)
            colors = [bbox_palette[label] for label in labels_3d]

664
665
            self.set_points(
                points, pcd_mode=2, mode='xyzrgb' if show_pcd_rgb else 'xyz')
666
            self.draw_bboxes_3d(bboxes_3d_depth, bbox_color=colors)
ZCMax's avatar
ZCMax committed
667

668
669
            data_3d['bboxes_3d'] = tensor2ndarray(bboxes_3d_depth.tensor)
            data_3d['points'] = points
ZCMax's avatar
ZCMax committed
670

671
        if vis_task in ['mono_det', 'multi-modality_det']:
ZCMax's avatar
ZCMax committed
672
            assert 'img' in data_input
673
            img = data_input['img']
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
            if isinstance(img, list) or (isinstance(img, (np.ndarray, Tensor))
                                         and len(img.shape) == 4):
                # show multi-view images
                img_size = img[0].shape[:2] if isinstance(
                    img, list) else img.shape[-2:]  # noqa: E501
                img_col = self.multi_imgs_col
                img_row = math.ceil(len(img) / img_col)
                composed_img = np.zeros(
                    (img_size[0] * img_row, img_size[1] * img_col, 3),
                    dtype=np.uint8)
                for i, single_img in enumerate(img):
                    # Note that we should keep the same order of elements both
                    # in `img` and `input_meta`
                    if isinstance(single_img, Tensor):
                        single_img = single_img.permute(1, 2, 0).numpy()
                        single_img = single_img[..., [2, 1, 0]]  # bgr to rgb
                    self.set_image(single_img)
                    single_img_meta = dict()
                    for key, meta in input_meta.items():
                        if isinstance(meta,
                                      (Sequence, np.ndarray,
                                       Tensor)) and len(meta) == len(img):
                            single_img_meta[key] = meta[i]
                        else:
                            single_img_meta[key] = meta
699
700
701
702
703
704
705
706

                    max_label = int(
                        max(labels_3d) if len(labels_3d) > 0 else 0)
                    bbox_color = palette if self.bbox_color is None \
                        else self.bbox_color
                    bbox_palette = get_palette(bbox_color, max_label + 1)
                    colors = [bbox_palette[label] for label in labels_3d]

707
708
709
                    self.draw_proj_bboxes_3d(
                        bboxes_3d,
                        single_img_meta,
710
711
                        img_size=single_img.shape[:2][::-1],
                        edge_colors=colors)
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
                    if vis_task == 'mono_det' and hasattr(
                            instances, 'centers_2d'):
                        centers_2d = instances.centers_2d
                        self.draw_points(centers_2d)
                    composed_img[(i // img_col) *
                                 img_size[0]:(i // img_col + 1) * img_size[0],
                                 (i % img_col) *
                                 img_size[1]:(i % img_col + 1) *
                                 img_size[1]] = self.get_image()
                data_3d['img'] = composed_img
            else:
                # show single-view image
                # TODO: Solve the problem: some line segments of 3d bboxes are
                # out of image by a large margin
                if isinstance(data_input['img'], Tensor):
                    img = img.permute(1, 2, 0).numpy()
                    img = img[..., [2, 1, 0]]  # bgr to rgb
                self.set_image(img)
730
731
732
733
734
735
736
737
738

                max_label = int(max(labels_3d) if len(labels_3d) > 0 else 0)
                bbox_color = palette if self.bbox_color is None \
                    else self.bbox_color
                bbox_palette = get_palette(bbox_color, max_label + 1)
                colors = [bbox_palette[label] for label in labels_3d]

                self.draw_proj_bboxes_3d(
                    bboxes_3d, input_meta, edge_colors=colors)
739
740
741
742
743
                if vis_task == 'mono_det' and hasattr(instances, 'centers_2d'):
                    centers_2d = instances.centers_2d
                    self.draw_points(centers_2d)
                drawn_img = self.get_image()
                data_3d['img'] = drawn_img
ZCMax's avatar
ZCMax committed
744
745
746
747

        return data_3d

    def _draw_pts_sem_seg(self,
748
                          points: Union[Tensor, np.ndarray],
zhangshilong's avatar
zhangshilong committed
749
                          pts_seg: PointData,
ZCMax's avatar
ZCMax committed
750
                          palette: Optional[List[tuple]] = None,
751
                          keep_index: Optional[int] = None) -> None:
752
753
754
        """Draw 3D semantic mask of GT or prediction.

        Args:
755
756
757
758
759
760
            points (Tensor or np.ndarray): The input point cloud to draw.
            pts_seg (:obj:`PointData`): Data structure for pixel-level
                annotations or predictions.
            palette (List[tuple], optional): Palette information corresponding
                to the category. Defaults to None.
            ignore_index (int, optional): Ignore category. Defaults to None.
761
        """
ZCMax's avatar
ZCMax committed
762
763
764
765
        check_type('points', points, (np.ndarray, Tensor))

        points = tensor2ndarray(points)
        pts_sem_seg = tensor2ndarray(pts_seg.pts_semantic_mask)
766
        palette = np.array(palette)
ZCMax's avatar
ZCMax committed
767

768
769
770
771
        if keep_index is not None:
            keep_index = tensor2ndarray(keep_index)
            points = points[keep_index]
            pts_sem_seg = pts_sem_seg[keep_index]
ZCMax's avatar
ZCMax committed
772
773
774
775

        pts_color = palette[pts_sem_seg]
        seg_color = np.concatenate([points[:, :3], pts_color], axis=1)

776
        self.draw_seg_mask(seg_color)
ZCMax's avatar
ZCMax committed
777
778
779

    @master_only
    def show(self,
780
             save_path: Optional[str] = None,
ZCMax's avatar
ZCMax committed
781
782
783
             drawn_img_3d: Optional[np.ndarray] = None,
             drawn_img: Optional[np.ndarray] = None,
             win_name: str = 'image',
784
785
             wait_time: int = -1,
             continue_key: str = 'right',
786
             vis_task: str = 'lidar_det') -> None:
787
        """Show the drawn point cloud/image.
ZCMax's avatar
ZCMax committed
788
789

        Args:
790
            save_path (str, optional): Path to save open3d visualized results.
791
792
793
794
                Defaults to None.
            drawn_img_3d (np.ndarray, optional): The image to show. If
                drawn_img_3d is not None, it will show the image got by
                Visualizer. Defaults to None.
ZCMax's avatar
ZCMax committed
795
            drawn_img (np.ndarray, optional): The image to show. If drawn_img
796
797
798
799
800
801
                is not None, it will show the image got by Visualizer.
                Defaults to None.
            win_name (str): The image title. Defaults to 'image'.
            wait_time (int): Delay in milliseconds. 0 is the special value that
                means "forever". Defaults to 0.
            continue_key (str): The key for users to continue. Defaults to ' '.
ZCMax's avatar
ZCMax committed
802
        """
803
804
805
806
807
808
809

        # In order to show multi-modal results at the same time, we show image
        # firstly and then show point cloud since the running of
        # Open3D will block the process
        if hasattr(self, '_image'):
            if drawn_img is None and drawn_img_3d is None:
                # use the image got by Visualizer.get_image()
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
                if vis_task == 'multi-modality_det':
                    import matplotlib.pyplot as plt
                    is_inline = 'inline' in plt.get_backend()
                    img = self.get_image() if drawn_img is None else drawn_img
                    self._init_manager(win_name)
                    fig = self.manager.canvas.figure
                    # remove white edges by set subplot margin
                    fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
                    fig.clear()
                    ax = fig.add_subplot()
                    ax.axis(False)
                    ax.imshow(img)
                    self.manager.canvas.draw()
                    if is_inline:
                        return fig
                    else:
                        fig.show()
                    self.manager.canvas.flush_events()
                else:
                    super().show(drawn_img_3d, win_name, wait_time,
830
                                 continue_key)
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
            else:
                if vis_task == 'multi-modality_det':
                    import matplotlib.pyplot as plt
                    is_inline = 'inline' in plt.get_backend()
                    img = drawn_img if drawn_img_3d is None else drawn_img_3d
                    self._init_manager(win_name)
                    fig = self.manager.canvas.figure
                    # remove white edges by set subplot margin
                    fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
                    fig.clear()
                    ax = fig.add_subplot()
                    ax.axis(False)
                    ax.imshow(img)
                    self.manager.canvas.draw()
                    if is_inline:
                        return fig
                    else:
                        fig.show()
                    self.manager.canvas.flush_events()
                else:
                    if drawn_img_3d is not None:
                        super().show(drawn_img_3d, win_name, wait_time,
                                     continue_key)
                    if drawn_img is not None:
                        super().show(drawn_img, win_name, wait_time,
                                     continue_key)
857

858
        if hasattr(self, 'o3d_vis'):
859
860
861
862
            if hasattr(self, 'view_port'):
                self.view_control.convert_from_pinhole_camera_parameters(
                    self.view_port)
            self.flag_exit = not self.o3d_vis.poll_events()
863
            self.o3d_vis.update_renderer()
864
865
866
            # if not hasattr(self, 'view_control'):
            #     self.o3d_vis.create_window()
            #     self.view_control = self.o3d_vis.get_view_control()
867
868
869
870
871
872
873
874
875
876
877
878
879
880
            self.view_port = \
                self.view_control.convert_to_pinhole_camera_parameters()  # noqa: E501
            if wait_time != -1:
                self.last_time = time.time()
                while time.time(
                ) - self.last_time < wait_time and self.o3d_vis.poll_events():
                    self.o3d_vis.update_renderer()
                    self.view_port = \
                        self.view_control.convert_to_pinhole_camera_parameters()  # noqa: E501
                while self.flag_pause and self.o3d_vis.poll_events():
                    self.o3d_vis.update_renderer()
                    self.view_port = \
                        self.view_control.convert_to_pinhole_camera_parameters()  # noqa: E501

881
            else:
882
883
884
885
886
887
888
889
                while not self.flag_next and self.o3d_vis.poll_events():
                    self.o3d_vis.update_renderer()
                    self.view_port = \
                        self.view_control.convert_to_pinhole_camera_parameters()  # noqa: E501
                self.flag_next = False
            self.o3d_vis.clear_geometries()
            try:
                del self.pcd
890
            except (KeyError, AttributeError):
891
                pass
892
            if save_path is not None:
893
894
895
                if not (save_path.endswith('.png')
                        or save_path.endswith('.jpg')):
                    save_path += '.png'
896
                self.o3d_vis.capture_screen_image(save_path)
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
            if self.flag_exit:
                self.o3d_vis.destroy_window()
                self.o3d_vis.close()
                self._clear_o3d_vis()
                sys.exit(0)

    def escape_callback(self, vis):
        self.o3d_vis.clear_geometries()
        self.o3d_vis.destroy_window()
        self.o3d_vis.close()
        self._clear_o3d_vis()
        sys.exit(0)

    def space_action_callback(self, vis, action, mods):
        if action == 1:
            if self.flag_pause:
                print_log(
                    'Playback continued, press [SPACE] to pause.',
                    logger='current')
            else:
                print_log(
                    'Playback paused, press [SPACE] to continue.',
                    logger='current')
            self.flag_pause = not self.flag_pause
        return True
922

923
924
925
    def right_callback(self, vis):
        self.flag_next = True
        return False
ZCMax's avatar
ZCMax committed
926

927
928
    # TODO: Support Visualize the 3D results from image and point cloud
    # respectively
ZCMax's avatar
ZCMax committed
929
930
931
932
    @master_only
    def add_datasample(self,
                       name: str,
                       data_input: dict,
933
                       data_sample: Optional[Det3DDataSample] = None,
ZCMax's avatar
ZCMax committed
934
935
936
937
938
                       draw_gt: bool = True,
                       draw_pred: bool = True,
                       show: bool = False,
                       wait_time: float = 0,
                       out_file: Optional[str] = None,
939
                       o3d_save_path: Optional[str] = None,
940
                       vis_task: str = 'mono_det',
ZCMax's avatar
ZCMax committed
941
                       pred_score_thr: float = 0.3,
942
943
                       step: int = 0,
                       show_pcd_rgb: bool = False) -> None:
ZCMax's avatar
ZCMax committed
944
945
        """Draw datasample and save to all backends.

946
947
948
949
950
        - If GT and prediction are plotted at the same time, they are displayed
          in a stitched image where the left image is the ground truth and the
          right image is the prediction.
        - If ``show`` is True, all storage backends are ignored, and the images
          will be displayed in a local window.
951
        - If ``out_file`` is specified, the drawn image will be saved to
952
          ``out_file``. It is usually used when the display is not available.
ZCMax's avatar
ZCMax committed
953
954
955
956
957

        Args:
            name (str): The image identifier.
            data_input (dict): It should include the point clouds or image
                to draw.
958
            data_sample (:obj:`Det3DDataSample`, optional): Prediction
ZCMax's avatar
ZCMax committed
959
960
                Det3DDataSample. Defaults to None.
            draw_gt (bool): Whether to draw GT Det3DDataSample.
961
                Defaults to True.
ZCMax's avatar
ZCMax committed
962
963
            draw_pred (bool): Whether to draw Prediction Det3DDataSample.
                Defaults to True.
964
965
            show (bool): Whether to display the drawn point clouds and image.
                Defaults to False.
ZCMax's avatar
ZCMax committed
966
            wait_time (float): The interval of show (s). Defaults to 0.
967
            out_file (str, optional): Path to output file. Defaults to None.
968
            o3d_save_path (str, optional): Path to save open3d visualized
969
970
                results. Defaults to None.
            vis_task (str): Visualization task. Defaults to 'mono_det'.
ZCMax's avatar
ZCMax committed
971
972
973
            pred_score_thr (float): The threshold to visualize the bboxes
                and masks. Defaults to 0.3.
            step (int): Global step value to record. Defaults to 0.
974
975
            show_pcd_rgb (bool): Whether to show RGB point cloud. Defaults to
                False.
ZCMax's avatar
ZCMax committed
976
        """
977
978
979
980
981
982
        assert vis_task in (
            'mono_det', 'multi-view_det', 'lidar_det', 'lidar_seg',
            'multi-modality_det'), f'got unexpected vis_task {vis_task}.'
        classes = self.dataset_meta.get('classes', None)
        # For object detection datasets, no palette is saved
        palette = self.dataset_meta.get('palette', None)
ZCMax's avatar
ZCMax committed
983
        ignore_index = self.dataset_meta.get('ignore_index', None)
984
        if vis_task == 'lidar_seg' and ignore_index is not None and 'pts_semantic_mask' in data_sample.gt_pts_seg:  # noqa: E501
985
            keep_index = data_sample.gt_pts_seg.pts_semantic_mask != ignore_index  # noqa: E501
Sun Jiahao's avatar
Sun Jiahao committed
986
987
        else:
            keep_index = None
ZCMax's avatar
ZCMax committed
988

989
990
991
992
993
        gt_data_3d = None
        pred_data_3d = None
        gt_img_data = None
        pred_img_data = None

994
995
996
997
998
999
        if not hasattr(self, 'o3d_vis') and vis_task in [
                'multi-view_det', 'lidar_det', 'lidar_seg',
                'multi-modality_det'
        ]:
            self.o3d_vis = self._initialize_o3d_vis(show=show)

1000
1001
1002
1003
        if draw_gt and data_sample is not None:
            if 'gt_instances_3d' in data_sample:
                gt_data_3d = self._draw_instances_3d(
                    data_input, data_sample.gt_instances_3d,
1004
                    data_sample.metainfo, vis_task, show_pcd_rgb, palette)
1005
            if 'gt_instances' in data_sample:
ChaimZhu's avatar
ChaimZhu committed
1006
1007
                if len(data_sample.gt_instances) > 0:
                    assert 'img' in data_input
1008
                    img = data_input['img']
ChaimZhu's avatar
ChaimZhu committed
1009
1010
1011
1012
1013
                    if isinstance(data_input['img'], Tensor):
                        img = data_input['img'].permute(1, 2, 0).numpy()
                        img = img[..., [2, 1, 0]]  # bgr to rgb
                    gt_img_data = self._draw_instances(
                        img, data_sample.gt_instances, classes, palette)
1014
            if 'gt_pts_seg' in data_sample and vis_task == 'lidar_seg':
ZCMax's avatar
ZCMax committed
1015
1016
                assert classes is not None, 'class information is ' \
                                            'not provided when ' \
1017
                                            'visualizing semantic ' \
ZCMax's avatar
ZCMax committed
1018
1019
                                            'segmentation results.'
                assert 'points' in data_input
1020
                self._draw_pts_sem_seg(data_input['points'],
1021
                                       data_sample.gt_pts_seg, palette,
1022
                                       keep_index)
ZCMax's avatar
ZCMax committed
1023

1024
1025
1026
        if draw_pred and data_sample is not None:
            if 'pred_instances_3d' in data_sample:
                pred_instances_3d = data_sample.pred_instances_3d
1027
                # .cpu can not be used for BaseInstance3DBoxes
1028
                # so we need to use .to('cpu')
ZCMax's avatar
ZCMax committed
1029
                pred_instances_3d = pred_instances_3d[
1030
                    pred_instances_3d.scores_3d > pred_score_thr].to('cpu')
ZCMax's avatar
ZCMax committed
1031
1032
                pred_data_3d = self._draw_instances_3d(data_input,
                                                       pred_instances_3d,
1033
                                                       data_sample.metainfo,
1034
1035
                                                       vis_task, show_pcd_rgb,
                                                       palette)
1036
1037
1038
            if 'pred_instances' in data_sample:
                if 'img' in data_input and len(data_sample.pred_instances) > 0:
                    pred_instances = data_sample.pred_instances
1039
                    pred_instances = pred_instances[
1040
                        pred_instances.scores > pred_score_thr].cpu()
1041
                    img = data_input['img']
1042
1043
1044
1045
1046
                    if isinstance(data_input['img'], Tensor):
                        img = data_input['img'].permute(1, 2, 0).numpy()
                        img = img[..., [2, 1, 0]]  # bgr to rgb
                    pred_img_data = self._draw_instances(
                        img, pred_instances, classes, palette)
1047
            if 'pred_pts_seg' in data_sample and vis_task == 'lidar_seg':
ZCMax's avatar
ZCMax committed
1048
1049
                assert classes is not None, 'class information is ' \
                                            'not provided when ' \
1050
                                            'visualizing semantic ' \
ZCMax's avatar
ZCMax committed
1051
1052
                                            'segmentation results.'
                assert 'points' in data_input
1053
1054
                self._draw_pts_sem_seg(data_input['points'],
                                       data_sample.pred_pts_seg, palette,
1055
                                       keep_index)
ZCMax's avatar
ZCMax committed
1056
1057

        # monocular 3d object detection image
1058
        if vis_task in ['mono_det', 'multi-modality_det']:
1059
1060
1061
1062
1063
1064
1065
            if gt_data_3d is not None and pred_data_3d is not None:
                drawn_img_3d = np.concatenate(
                    (gt_data_3d['img'], pred_data_3d['img']), axis=1)
            elif gt_data_3d is not None:
                drawn_img_3d = gt_data_3d['img']
            elif pred_data_3d is not None:
                drawn_img_3d = pred_data_3d['img']
1066
1067
            else:  # both instances of gt and pred are empty
                drawn_img_3d = None
ZCMax's avatar
ZCMax committed
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
        else:
            drawn_img_3d = None

        # 2d object detection image
        if gt_img_data is not None and pred_img_data is not None:
            drawn_img = np.concatenate((gt_img_data, pred_img_data), axis=1)
        elif gt_img_data is not None:
            drawn_img = gt_img_data
        elif pred_img_data is not None:
            drawn_img = pred_img_data
        else:
            drawn_img = None

        if show:
            self.show(
1083
                o3d_save_path,
ZCMax's avatar
ZCMax committed
1084
1085
1086
                drawn_img_3d,
                drawn_img,
                win_name=name,
1087
1088
                wait_time=wait_time,
                vis_task=vis_task)
ZCMax's avatar
ZCMax committed
1089
1090

        if out_file is not None:
1091
1092
1093
            # check the suffix of the name of image file
            if not (out_file.endswith('.png') or out_file.endswith('.jpg')):
                out_file = f'{out_file}.png'
ZCMax's avatar
ZCMax committed
1094
            if drawn_img_3d is not None:
1095
                mmcv.imwrite(drawn_img_3d[..., ::-1], out_file)
ZCMax's avatar
ZCMax committed
1096
            if drawn_img is not None:
1097
1098
                mmcv.imwrite(drawn_img[..., ::-1],
                             out_file[:-4] + '_2d' + out_file[-4:])
ZCMax's avatar
ZCMax committed
1099
1100
        else:
            self.add_image(name, drawn_img_3d, step)