Unverified Commit 12e2b7cf authored by Sun Jiahao's avatar Sun Jiahao Committed by GitHub
Browse files

[Enhance] Support LiDAR visualization (#2611)

* need fix multi_modality

* update multi modal

* remove pcd

* fix mix bug

* fix nusence_mini

* fix msehframe

* fix flag exit

* add space line
parent 2c136730
......@@ -113,7 +113,7 @@ test_pipeline = [
dataset_type='semantickitti',
backend_args=backend_args),
dict(type='PointSegClassMapping'),
dict(type='Pack3DDetInputs', keys=['points'])
dict(type='Pack3DDetInputs', keys=['points', 'pts_semantic_mask'])
]
# construct a pipeline for data and gt loading in show function
# please keep its loading function consistent with test_pipeline (e.g. client)
......
......@@ -59,7 +59,7 @@ def main(args):
data_sample=result,
draw_gt=False,
show=args.show,
wait_time=0,
wait_time=-1,
out_file=args.out_dir,
pred_score_thr=args.score_thr,
vis_task='mono_det')
......
......@@ -67,7 +67,7 @@ def main(args):
data_sample=result,
draw_gt=False,
show=args.show,
wait_time=0,
wait_time=-1,
out_file=args.out_dir,
pred_score_thr=args.score_thr,
vis_task='multi-modality_det')
......
......@@ -49,7 +49,7 @@ def main(args):
data_sample=result,
draw_gt=False,
show=args.show,
wait_time=0,
wait_time=-1,
out_file=args.out_dir,
pred_score_thr=args.score_thr,
vis_task='lidar_det')
......
......@@ -45,7 +45,7 @@ def main(args):
data_sample=result,
draw_gt=False,
show=args.show,
wait_time=0,
wait_time=-1,
out_file=args.out_dir,
vis_task='lidar_seg')
......
......@@ -7,6 +7,7 @@ import mmcv
import numpy as np
from mmengine.fileio import get
from mmengine.hooks import Hook
from mmengine.logging import print_log
from mmengine.runner import Runner
from mmengine.utils import mkdir_or_exist
from mmengine.visualization import Visualizer
......@@ -56,6 +57,8 @@ class Det3DVisualizationHook(Hook):
vis_task: str = 'mono_det',
wait_time: float = 0.,
test_out_dir: Optional[str] = None,
draw_gt: bool = True,
draw_pred: bool = True,
backend_args: Optional[dict] = None):
self._visualizer: Visualizer = Visualizer.get_current_instance()
self.interval = interval
......@@ -70,11 +73,20 @@ class Det3DVisualizationHook(Hook):
'needs to be excluded.')
self.vis_task = vis_task
if wait_time == -1:
print_log(
'Manual control mode, press [Right] to next sample.',
logger='current')
else:
print_log(
'Autoplay mode, press [SPACE] to pause.', logger='current')
self.wait_time = wait_time
self.backend_args = backend_args
self.draw = draw
self.test_out_dir = test_out_dir
self._test_index = 0
self.draw_gt = draw_gt
self.draw_pred = draw_pred
def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
outputs: Sequence[Det3DDataSample]) -> None:
......@@ -208,6 +220,8 @@ class Det3DVisualizationHook(Hook):
'test sample',
data_input,
data_sample=data_sample,
draw_gt=self.draw_gt,
draw_pred=self.draw_pred,
show=self.show,
vis_task=self.vis_task,
wait_time=self.wait_time,
......
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import math
import sys
import time
from typing import List, Optional, Sequence, Tuple, Union
......@@ -12,6 +13,7 @@ from matplotlib.patches import PathPatch
from matplotlib.path import Path
from mmdet.visualization import DetLocalVisualizer, get_palette
from mmengine.dist import master_only
from mmengine.logging import print_log
from mmengine.structures import InstanceData
from mmengine.visualization import Visualizer as MMENGINE_Visualizer
from mmengine.visualization.utils import (check_type, color_val_matplotlib,
......@@ -136,19 +138,24 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
alpha=alpha)
if points is not None:
self.set_points(points, pcd_mode=pcd_mode, frame_cfg=frame_cfg)
self.pts_seg_num = 0
self.multi_imgs_col = multi_imgs_col
self.fig_show_cfg.update(fig_show_cfg)
self.flag_pause = False
self.flag_next = False
self.flag_exit = False
def _clear_o3d_vis(self) -> None:
"""Clear open3d vis."""
if hasattr(self, 'o3d_vis'):
del self.o3d_vis
del self.pcd
del self.points_colors
del self.view_control
if hasattr(self, 'pcd'):
del self.pcd
def _initialize_o3d_vis(self, frame_cfg: dict) -> Visualizer:
def _initialize_o3d_vis(self) -> Visualizer:
"""Initialize open3d vis according to frame_cfg.
Args:
......@@ -161,11 +168,16 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
if o3d is None or geometry is None:
raise ImportError(
'Please run "pip install open3d" to install open3d first.')
o3d_vis = o3d.visualization.Visualizer()
glfw_key_escape = 256 # Esc
glfw_key_space = 32 # Space
glfw_key_right = 262 # Right
o3d_vis = o3d.visualization.VisualizerWithKeyCallback()
o3d_vis.register_key_callback(glfw_key_escape, self.escape_callback)
o3d_vis.register_key_action_callback(glfw_key_space,
self.space_action_callback)
o3d_vis.register_key_callback(glfw_key_right, self.right_callback)
o3d_vis.create_window()
# create coordinate frame
mesh_frame = geometry.TriangleMesh.create_coordinate_frame(**frame_cfg)
o3d_vis.add_geometry(mesh_frame)
self.view_control = o3d_vis.get_view_control()
return o3d_vis
@master_only
......@@ -205,7 +217,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
check_type('points', points, np.ndarray)
if not hasattr(self, 'o3d_vis'):
self.o3d_vis = self._initialize_o3d_vis(frame_cfg)
self.o3d_vis = self._initialize_o3d_vis()
# for now we convert points into depth mode for visualization
if pcd_mode != Coord3DMode.DEPTH:
......@@ -235,6 +247,10 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
else:
raise NotImplementedError
# create coordinate frame
mesh_frame = geometry.TriangleMesh.create_coordinate_frame(**frame_cfg)
self.o3d_vis.add_geometry(mesh_frame)
pcd.colors = o3d.utility.Vector3dVector(points_colors)
self.o3d_vis.add_geometry(pcd)
self.pcd = pcd
......@@ -572,12 +588,15 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
# we can't draw the colors on existing points
# in case gt and pred mask would overlap
# instead we set a large offset along x-axis for each seg mask
self.pts_seg_num += 1
offset = (np.array(self.pcd.points).max(0) -
np.array(self.pcd.points).min(0))[0] * 1.2 * self.pts_seg_num
mesh_frame = geometry.TriangleMesh.create_coordinate_frame(
size=1, origin=[offset, 0, 0]) # create coordinate frame for seg
self.o3d_vis.add_geometry(mesh_frame)
if hasattr(self, 'pcd'):
offset = (np.array(self.pcd.points).max(0) -
np.array(self.pcd.points).min(0))[0] * 1.2
mesh_frame = geometry.TriangleMesh.create_coordinate_frame(
size=1, origin=[offset, 0,
0]) # create coordinate frame for seg
self.o3d_vis.add_geometry(mesh_frame)
else:
offset = 0
seg_points = copy.deepcopy(seg_mask_colors)
seg_points[:, 0] += offset
self.set_points(seg_points, pcd_mode=2, vis_mode='add', mode='xyzrgb')
......@@ -716,7 +735,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
points: Union[Tensor, np.ndarray],
pts_seg: PointData,
palette: Optional[List[tuple]] = None,
ignore_index: Optional[int] = None) -> None:
keep_index: Optional[int] = None) -> None:
"""Draw 3D semantic mask of GT or prediction.
Args:
......@@ -733,14 +752,14 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
pts_sem_seg = tensor2ndarray(pts_seg.pts_semantic_mask)
palette = np.array(palette)
if ignore_index is not None:
points = points[pts_sem_seg != ignore_index]
pts_sem_seg = pts_sem_seg[pts_sem_seg != ignore_index]
if keep_index is not None:
keep_index = tensor2ndarray(keep_index)
points = points[keep_index]
pts_sem_seg = pts_sem_seg[keep_index]
pts_color = palette[pts_sem_seg]
seg_color = np.concatenate([points[:, :3], pts_color], axis=1)
self.set_points(points, pcd_mode=2, vis_mode='add')
self.draw_seg_mask(seg_color)
@master_only
......@@ -749,8 +768,8 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
drawn_img_3d: Optional[np.ndarray] = None,
drawn_img: Optional[np.ndarray] = None,
win_name: str = 'image',
wait_time: int = 0,
continue_key: str = ' ',
wait_time: int = -1,
continue_key: str = 'right',
vis_task: str = 'lidar_det') -> None:
"""Show the drawn point cloud/image.
......@@ -768,10 +787,6 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
means "forever". Defaults to 0.
continue_key (str): The key for users to continue. Defaults to ' '.
"""
if vis_task == 'multi-modality_det':
img_wait_time = 0.5
else:
img_wait_time = wait_time
# In order to show multi-modal results at the same time, we show image
# firstly and then show point cloud since the running of
......@@ -779,34 +794,119 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
if hasattr(self, '_image'):
if drawn_img is None and drawn_img_3d is None:
# use the image got by Visualizer.get_image()
super().show(drawn_img_3d, win_name, img_wait_time,
continue_key)
else:
if drawn_img_3d is not None:
super().show(drawn_img_3d, win_name, img_wait_time,
continue_key)
if drawn_img is not None:
super().show(drawn_img, win_name, img_wait_time,
if vis_task == 'multi-modality_det':
import matplotlib.pyplot as plt
is_inline = 'inline' in plt.get_backend()
img = self.get_image() if drawn_img is None else drawn_img
self._init_manager(win_name)
fig = self.manager.canvas.figure
# remove white edges by set subplot margin
fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
fig.clear()
ax = fig.add_subplot()
ax.axis(False)
ax.imshow(img)
self.manager.canvas.draw()
if is_inline:
return fig
else:
fig.show()
self.manager.canvas.flush_events()
else:
super().show(drawn_img_3d, win_name, wait_time,
continue_key)
else:
if vis_task == 'multi-modality_det':
import matplotlib.pyplot as plt
is_inline = 'inline' in plt.get_backend()
img = drawn_img if drawn_img_3d is None else drawn_img_3d
self._init_manager(win_name)
fig = self.manager.canvas.figure
# remove white edges by set subplot margin
fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
fig.clear()
ax = fig.add_subplot()
ax.axis(False)
ax.imshow(img)
self.manager.canvas.draw()
if is_inline:
return fig
else:
fig.show()
self.manager.canvas.flush_events()
else:
if drawn_img_3d is not None:
super().show(drawn_img_3d, win_name, wait_time,
continue_key)
if drawn_img is not None:
super().show(drawn_img, win_name, wait_time,
continue_key)
if hasattr(self, 'o3d_vis'):
self.o3d_vis.poll_events()
if hasattr(self, 'view_port'):
self.view_control.convert_from_pinhole_camera_parameters(
self.view_port)
self.flag_exit = not self.o3d_vis.poll_events()
self.o3d_vis.update_renderer()
if wait_time > 0:
time.sleep(wait_time)
self.view_port = \
self.view_control.convert_to_pinhole_camera_parameters() # noqa: E501
if wait_time != -1:
self.last_time = time.time()
while time.time(
) - self.last_time < wait_time and self.o3d_vis.poll_events():
self.o3d_vis.update_renderer()
self.view_port = \
self.view_control.convert_to_pinhole_camera_parameters() # noqa: E501
while self.flag_pause and self.o3d_vis.poll_events():
self.o3d_vis.update_renderer()
self.view_port = \
self.view_control.convert_to_pinhole_camera_parameters() # noqa: E501
else:
self.o3d_vis.run()
while not self.flag_next and self.o3d_vis.poll_events():
self.o3d_vis.update_renderer()
self.view_port = \
self.view_control.convert_to_pinhole_camera_parameters() # noqa: E501
self.flag_next = False
self.o3d_vis.clear_geometries()
try:
del self.pcd
except KeyError:
pass
if save_path is not None:
if not (save_path.endswith('.png')
or save_path.endswith('.jpg')):
save_path += '.png'
self.o3d_vis.capture_screen_image(save_path)
if self.flag_exit:
self.o3d_vis.destroy_window()
self.o3d_vis.close()
self._clear_o3d_vis()
sys.exit(0)
def escape_callback(self, vis):
self.o3d_vis.clear_geometries()
self.o3d_vis.destroy_window()
self.o3d_vis.close()
self._clear_o3d_vis()
sys.exit(0)
def space_action_callback(self, vis, action, mods):
if action == 1:
if self.flag_pause:
print_log(
'Playback continued, press [SPACE] to pause.',
logger='current')
else:
print_log(
'Playback paused, press [SPACE] to continue.',
logger='current')
self.flag_pause = not self.flag_pause
return True
# TODO: support more flexible window control
self.o3d_vis.clear_geometries()
self.o3d_vis.destroy_window()
self.o3d_vis.close()
self._clear_o3d_vis()
def right_callback(self, vis):
self.flag_next = True
return False
# TODO: Support Visualize the 3D results from image and point cloud
# respectively
......@@ -862,6 +962,8 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
# For object detection datasets, no palette is saved
palette = self.dataset_meta.get('palette', None)
ignore_index = self.dataset_meta.get('ignore_index', None)
if ignore_index is not None and 'gt_pts_seg' in data_sample and vis_task == 'lidar_seg': # noqa: E501
keep_index = data_sample.gt_pts_seg.pts_semantic_mask != ignore_index # noqa: E501
gt_data_3d = None
pred_data_3d = None
......@@ -890,7 +992,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
assert 'points' in data_input
self._draw_pts_sem_seg(data_input['points'],
data_sample.gt_pts_seg, palette,
ignore_index)
keep_index)
if draw_pred and data_sample is not None:
if 'pred_instances_3d' in data_sample:
......@@ -922,7 +1024,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
assert 'points' in data_input
self._draw_pts_sem_seg(data_input['points'],
data_sample.pred_pts_seg, palette,
ignore_index)
keep_index)
# monocular 3d object detection image
if vis_task in ['mono_det', 'multi-modality_det']:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment