Unverified Commit 12e2b7cf authored by Sun Jiahao's avatar Sun Jiahao Committed by GitHub
Browse files

[Enhance] Support LiDAR visualization (#2611)

* need fix multi_modality

* update multi modal

* remove pcd

* fix mix bug

* fix nusence_mini

* fix msehframe

* fix flag exit

* add space line
parent 2c136730
...@@ -113,7 +113,7 @@ test_pipeline = [ ...@@ -113,7 +113,7 @@ test_pipeline = [
dataset_type='semantickitti', dataset_type='semantickitti',
backend_args=backend_args), backend_args=backend_args),
dict(type='PointSegClassMapping'), dict(type='PointSegClassMapping'),
dict(type='Pack3DDetInputs', keys=['points']) dict(type='Pack3DDetInputs', keys=['points', 'pts_semantic_mask'])
] ]
# construct a pipeline for data and gt loading in show function # construct a pipeline for data and gt loading in show function
# please keep its loading function consistent with test_pipeline (e.g. client) # please keep its loading function consistent with test_pipeline (e.g. client)
......
...@@ -59,7 +59,7 @@ def main(args): ...@@ -59,7 +59,7 @@ def main(args):
data_sample=result, data_sample=result,
draw_gt=False, draw_gt=False,
show=args.show, show=args.show,
wait_time=0, wait_time=-1,
out_file=args.out_dir, out_file=args.out_dir,
pred_score_thr=args.score_thr, pred_score_thr=args.score_thr,
vis_task='mono_det') vis_task='mono_det')
......
...@@ -67,7 +67,7 @@ def main(args): ...@@ -67,7 +67,7 @@ def main(args):
data_sample=result, data_sample=result,
draw_gt=False, draw_gt=False,
show=args.show, show=args.show,
wait_time=0, wait_time=-1,
out_file=args.out_dir, out_file=args.out_dir,
pred_score_thr=args.score_thr, pred_score_thr=args.score_thr,
vis_task='multi-modality_det') vis_task='multi-modality_det')
......
...@@ -49,7 +49,7 @@ def main(args): ...@@ -49,7 +49,7 @@ def main(args):
data_sample=result, data_sample=result,
draw_gt=False, draw_gt=False,
show=args.show, show=args.show,
wait_time=0, wait_time=-1,
out_file=args.out_dir, out_file=args.out_dir,
pred_score_thr=args.score_thr, pred_score_thr=args.score_thr,
vis_task='lidar_det') vis_task='lidar_det')
......
...@@ -45,7 +45,7 @@ def main(args): ...@@ -45,7 +45,7 @@ def main(args):
data_sample=result, data_sample=result,
draw_gt=False, draw_gt=False,
show=args.show, show=args.show,
wait_time=0, wait_time=-1,
out_file=args.out_dir, out_file=args.out_dir,
vis_task='lidar_seg') vis_task='lidar_seg')
......
...@@ -7,6 +7,7 @@ import mmcv ...@@ -7,6 +7,7 @@ import mmcv
import numpy as np import numpy as np
from mmengine.fileio import get from mmengine.fileio import get
from mmengine.hooks import Hook from mmengine.hooks import Hook
from mmengine.logging import print_log
from mmengine.runner import Runner from mmengine.runner import Runner
from mmengine.utils import mkdir_or_exist from mmengine.utils import mkdir_or_exist
from mmengine.visualization import Visualizer from mmengine.visualization import Visualizer
...@@ -56,6 +57,8 @@ class Det3DVisualizationHook(Hook): ...@@ -56,6 +57,8 @@ class Det3DVisualizationHook(Hook):
vis_task: str = 'mono_det', vis_task: str = 'mono_det',
wait_time: float = 0., wait_time: float = 0.,
test_out_dir: Optional[str] = None, test_out_dir: Optional[str] = None,
draw_gt: bool = True,
draw_pred: bool = True,
backend_args: Optional[dict] = None): backend_args: Optional[dict] = None):
self._visualizer: Visualizer = Visualizer.get_current_instance() self._visualizer: Visualizer = Visualizer.get_current_instance()
self.interval = interval self.interval = interval
...@@ -70,11 +73,20 @@ class Det3DVisualizationHook(Hook): ...@@ -70,11 +73,20 @@ class Det3DVisualizationHook(Hook):
'needs to be excluded.') 'needs to be excluded.')
self.vis_task = vis_task self.vis_task = vis_task
if wait_time == -1:
print_log(
'Manual control mode, press [Right] to next sample.',
logger='current')
else:
print_log(
'Autoplay mode, press [SPACE] to pause.', logger='current')
self.wait_time = wait_time self.wait_time = wait_time
self.backend_args = backend_args self.backend_args = backend_args
self.draw = draw self.draw = draw
self.test_out_dir = test_out_dir self.test_out_dir = test_out_dir
self._test_index = 0 self._test_index = 0
self.draw_gt = draw_gt
self.draw_pred = draw_pred
def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict, def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
outputs: Sequence[Det3DDataSample]) -> None: outputs: Sequence[Det3DDataSample]) -> None:
...@@ -208,6 +220,8 @@ class Det3DVisualizationHook(Hook): ...@@ -208,6 +220,8 @@ class Det3DVisualizationHook(Hook):
'test sample', 'test sample',
data_input, data_input,
data_sample=data_sample, data_sample=data_sample,
draw_gt=self.draw_gt,
draw_pred=self.draw_pred,
show=self.show, show=self.show,
vis_task=self.vis_task, vis_task=self.vis_task,
wait_time=self.wait_time, wait_time=self.wait_time,
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import copy import copy
import math import math
import sys
import time import time
from typing import List, Optional, Sequence, Tuple, Union from typing import List, Optional, Sequence, Tuple, Union
...@@ -12,6 +13,7 @@ from matplotlib.patches import PathPatch ...@@ -12,6 +13,7 @@ from matplotlib.patches import PathPatch
from matplotlib.path import Path from matplotlib.path import Path
from mmdet.visualization import DetLocalVisualizer, get_palette from mmdet.visualization import DetLocalVisualizer, get_palette
from mmengine.dist import master_only from mmengine.dist import master_only
from mmengine.logging import print_log
from mmengine.structures import InstanceData from mmengine.structures import InstanceData
from mmengine.visualization import Visualizer as MMENGINE_Visualizer from mmengine.visualization import Visualizer as MMENGINE_Visualizer
from mmengine.visualization.utils import (check_type, color_val_matplotlib, from mmengine.visualization.utils import (check_type, color_val_matplotlib,
...@@ -136,19 +138,24 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -136,19 +138,24 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
alpha=alpha) alpha=alpha)
if points is not None: if points is not None:
self.set_points(points, pcd_mode=pcd_mode, frame_cfg=frame_cfg) self.set_points(points, pcd_mode=pcd_mode, frame_cfg=frame_cfg)
self.pts_seg_num = 0
self.multi_imgs_col = multi_imgs_col self.multi_imgs_col = multi_imgs_col
self.fig_show_cfg.update(fig_show_cfg) self.fig_show_cfg.update(fig_show_cfg)
self.flag_pause = False
self.flag_next = False
self.flag_exit = False
def _clear_o3d_vis(self) -> None: def _clear_o3d_vis(self) -> None:
"""Clear open3d vis.""" """Clear open3d vis."""
if hasattr(self, 'o3d_vis'): if hasattr(self, 'o3d_vis'):
del self.o3d_vis del self.o3d_vis
del self.pcd
del self.points_colors del self.points_colors
del self.view_control
if hasattr(self, 'pcd'):
del self.pcd
def _initialize_o3d_vis(self, frame_cfg: dict) -> Visualizer: def _initialize_o3d_vis(self) -> Visualizer:
"""Initialize open3d vis according to frame_cfg. """Initialize open3d vis according to frame_cfg.
Args: Args:
...@@ -161,11 +168,16 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -161,11 +168,16 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
if o3d is None or geometry is None: if o3d is None or geometry is None:
raise ImportError( raise ImportError(
'Please run "pip install open3d" to install open3d first.') 'Please run "pip install open3d" to install open3d first.')
o3d_vis = o3d.visualization.Visualizer() glfw_key_escape = 256 # Esc
glfw_key_space = 32 # Space
glfw_key_right = 262 # Right
o3d_vis = o3d.visualization.VisualizerWithKeyCallback()
o3d_vis.register_key_callback(glfw_key_escape, self.escape_callback)
o3d_vis.register_key_action_callback(glfw_key_space,
self.space_action_callback)
o3d_vis.register_key_callback(glfw_key_right, self.right_callback)
o3d_vis.create_window() o3d_vis.create_window()
# create coordinate frame self.view_control = o3d_vis.get_view_control()
mesh_frame = geometry.TriangleMesh.create_coordinate_frame(**frame_cfg)
o3d_vis.add_geometry(mesh_frame)
return o3d_vis return o3d_vis
@master_only @master_only
...@@ -205,7 +217,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -205,7 +217,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
check_type('points', points, np.ndarray) check_type('points', points, np.ndarray)
if not hasattr(self, 'o3d_vis'): if not hasattr(self, 'o3d_vis'):
self.o3d_vis = self._initialize_o3d_vis(frame_cfg) self.o3d_vis = self._initialize_o3d_vis()
# for now we convert points into depth mode for visualization # for now we convert points into depth mode for visualization
if pcd_mode != Coord3DMode.DEPTH: if pcd_mode != Coord3DMode.DEPTH:
...@@ -235,6 +247,10 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -235,6 +247,10 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
else: else:
raise NotImplementedError raise NotImplementedError
# create coordinate frame
mesh_frame = geometry.TriangleMesh.create_coordinate_frame(**frame_cfg)
self.o3d_vis.add_geometry(mesh_frame)
pcd.colors = o3d.utility.Vector3dVector(points_colors) pcd.colors = o3d.utility.Vector3dVector(points_colors)
self.o3d_vis.add_geometry(pcd) self.o3d_vis.add_geometry(pcd)
self.pcd = pcd self.pcd = pcd
...@@ -572,12 +588,15 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -572,12 +588,15 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
# we can't draw the colors on existing points # we can't draw the colors on existing points
# in case gt and pred mask would overlap # in case gt and pred mask would overlap
# instead we set a large offset along x-axis for each seg mask # instead we set a large offset along x-axis for each seg mask
self.pts_seg_num += 1 if hasattr(self, 'pcd'):
offset = (np.array(self.pcd.points).max(0) - offset = (np.array(self.pcd.points).max(0) -
np.array(self.pcd.points).min(0))[0] * 1.2 * self.pts_seg_num np.array(self.pcd.points).min(0))[0] * 1.2
mesh_frame = geometry.TriangleMesh.create_coordinate_frame( mesh_frame = geometry.TriangleMesh.create_coordinate_frame(
size=1, origin=[offset, 0, 0]) # create coordinate frame for seg size=1, origin=[offset, 0,
self.o3d_vis.add_geometry(mesh_frame) 0]) # create coordinate frame for seg
self.o3d_vis.add_geometry(mesh_frame)
else:
offset = 0
seg_points = copy.deepcopy(seg_mask_colors) seg_points = copy.deepcopy(seg_mask_colors)
seg_points[:, 0] += offset seg_points[:, 0] += offset
self.set_points(seg_points, pcd_mode=2, vis_mode='add', mode='xyzrgb') self.set_points(seg_points, pcd_mode=2, vis_mode='add', mode='xyzrgb')
...@@ -716,7 +735,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -716,7 +735,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
points: Union[Tensor, np.ndarray], points: Union[Tensor, np.ndarray],
pts_seg: PointData, pts_seg: PointData,
palette: Optional[List[tuple]] = None, palette: Optional[List[tuple]] = None,
ignore_index: Optional[int] = None) -> None: keep_index: Optional[int] = None) -> None:
"""Draw 3D semantic mask of GT or prediction. """Draw 3D semantic mask of GT or prediction.
Args: Args:
...@@ -733,14 +752,14 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -733,14 +752,14 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
pts_sem_seg = tensor2ndarray(pts_seg.pts_semantic_mask) pts_sem_seg = tensor2ndarray(pts_seg.pts_semantic_mask)
palette = np.array(palette) palette = np.array(palette)
if ignore_index is not None: if keep_index is not None:
points = points[pts_sem_seg != ignore_index] keep_index = tensor2ndarray(keep_index)
pts_sem_seg = pts_sem_seg[pts_sem_seg != ignore_index] points = points[keep_index]
pts_sem_seg = pts_sem_seg[keep_index]
pts_color = palette[pts_sem_seg] pts_color = palette[pts_sem_seg]
seg_color = np.concatenate([points[:, :3], pts_color], axis=1) seg_color = np.concatenate([points[:, :3], pts_color], axis=1)
self.set_points(points, pcd_mode=2, vis_mode='add')
self.draw_seg_mask(seg_color) self.draw_seg_mask(seg_color)
@master_only @master_only
...@@ -749,8 +768,8 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -749,8 +768,8 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
drawn_img_3d: Optional[np.ndarray] = None, drawn_img_3d: Optional[np.ndarray] = None,
drawn_img: Optional[np.ndarray] = None, drawn_img: Optional[np.ndarray] = None,
win_name: str = 'image', win_name: str = 'image',
wait_time: int = 0, wait_time: int = -1,
continue_key: str = ' ', continue_key: str = 'right',
vis_task: str = 'lidar_det') -> None: vis_task: str = 'lidar_det') -> None:
"""Show the drawn point cloud/image. """Show the drawn point cloud/image.
...@@ -768,10 +787,6 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -768,10 +787,6 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
means "forever". Defaults to 0. means "forever". Defaults to 0.
continue_key (str): The key for users to continue. Defaults to ' '. continue_key (str): The key for users to continue. Defaults to ' '.
""" """
if vis_task == 'multi-modality_det':
img_wait_time = 0.5
else:
img_wait_time = wait_time
# In order to show multi-modal results at the same time, we show image # In order to show multi-modal results at the same time, we show image
# firstly and then show point cloud since the running of # firstly and then show point cloud since the running of
...@@ -779,34 +794,119 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -779,34 +794,119 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
if hasattr(self, '_image'): if hasattr(self, '_image'):
if drawn_img is None and drawn_img_3d is None: if drawn_img is None and drawn_img_3d is None:
# use the image got by Visualizer.get_image() # use the image got by Visualizer.get_image()
super().show(drawn_img_3d, win_name, img_wait_time, if vis_task == 'multi-modality_det':
continue_key) import matplotlib.pyplot as plt
else: is_inline = 'inline' in plt.get_backend()
if drawn_img_3d is not None: img = self.get_image() if drawn_img is None else drawn_img
super().show(drawn_img_3d, win_name, img_wait_time, self._init_manager(win_name)
continue_key) fig = self.manager.canvas.figure
if drawn_img is not None: # remove white edges by set subplot margin
super().show(drawn_img, win_name, img_wait_time, fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
fig.clear()
ax = fig.add_subplot()
ax.axis(False)
ax.imshow(img)
self.manager.canvas.draw()
if is_inline:
return fig
else:
fig.show()
self.manager.canvas.flush_events()
else:
super().show(drawn_img_3d, win_name, wait_time,
continue_key) continue_key)
else:
if vis_task == 'multi-modality_det':
import matplotlib.pyplot as plt
is_inline = 'inline' in plt.get_backend()
img = drawn_img if drawn_img_3d is None else drawn_img_3d
self._init_manager(win_name)
fig = self.manager.canvas.figure
# remove white edges by set subplot margin
fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
fig.clear()
ax = fig.add_subplot()
ax.axis(False)
ax.imshow(img)
self.manager.canvas.draw()
if is_inline:
return fig
else:
fig.show()
self.manager.canvas.flush_events()
else:
if drawn_img_3d is not None:
super().show(drawn_img_3d, win_name, wait_time,
continue_key)
if drawn_img is not None:
super().show(drawn_img, win_name, wait_time,
continue_key)
if hasattr(self, 'o3d_vis'): if hasattr(self, 'o3d_vis'):
self.o3d_vis.poll_events() if hasattr(self, 'view_port'):
self.view_control.convert_from_pinhole_camera_parameters(
self.view_port)
self.flag_exit = not self.o3d_vis.poll_events()
self.o3d_vis.update_renderer() self.o3d_vis.update_renderer()
if wait_time > 0: self.view_port = \
time.sleep(wait_time) self.view_control.convert_to_pinhole_camera_parameters() # noqa: E501
if wait_time != -1:
self.last_time = time.time()
while time.time(
) - self.last_time < wait_time and self.o3d_vis.poll_events():
self.o3d_vis.update_renderer()
self.view_port = \
self.view_control.convert_to_pinhole_camera_parameters() # noqa: E501
while self.flag_pause and self.o3d_vis.poll_events():
self.o3d_vis.update_renderer()
self.view_port = \
self.view_control.convert_to_pinhole_camera_parameters() # noqa: E501
else: else:
self.o3d_vis.run() while not self.flag_next and self.o3d_vis.poll_events():
self.o3d_vis.update_renderer()
self.view_port = \
self.view_control.convert_to_pinhole_camera_parameters() # noqa: E501
self.flag_next = False
self.o3d_vis.clear_geometries()
try:
del self.pcd
except KeyError:
pass
if save_path is not None: if save_path is not None:
if not (save_path.endswith('.png') if not (save_path.endswith('.png')
or save_path.endswith('.jpg')): or save_path.endswith('.jpg')):
save_path += '.png' save_path += '.png'
self.o3d_vis.capture_screen_image(save_path) self.o3d_vis.capture_screen_image(save_path)
if self.flag_exit:
self.o3d_vis.destroy_window()
self.o3d_vis.close()
self._clear_o3d_vis()
sys.exit(0)
def escape_callback(self, vis):
self.o3d_vis.clear_geometries()
self.o3d_vis.destroy_window()
self.o3d_vis.close()
self._clear_o3d_vis()
sys.exit(0)
def space_action_callback(self, vis, action, mods):
if action == 1:
if self.flag_pause:
print_log(
'Playback continued, press [SPACE] to pause.',
logger='current')
else:
print_log(
'Playback paused, press [SPACE] to continue.',
logger='current')
self.flag_pause = not self.flag_pause
return True
# TODO: support more flexible window control def right_callback(self, vis):
self.o3d_vis.clear_geometries() self.flag_next = True
self.o3d_vis.destroy_window() return False
self.o3d_vis.close()
self._clear_o3d_vis()
# TODO: Support Visualize the 3D results from image and point cloud # TODO: Support Visualize the 3D results from image and point cloud
# respectively # respectively
...@@ -862,6 +962,8 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -862,6 +962,8 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
# For object detection datasets, no palette is saved # For object detection datasets, no palette is saved
palette = self.dataset_meta.get('palette', None) palette = self.dataset_meta.get('palette', None)
ignore_index = self.dataset_meta.get('ignore_index', None) ignore_index = self.dataset_meta.get('ignore_index', None)
if ignore_index is not None and 'gt_pts_seg' in data_sample and vis_task == 'lidar_seg': # noqa: E501
keep_index = data_sample.gt_pts_seg.pts_semantic_mask != ignore_index # noqa: E501
gt_data_3d = None gt_data_3d = None
pred_data_3d = None pred_data_3d = None
...@@ -890,7 +992,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -890,7 +992,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
assert 'points' in data_input assert 'points' in data_input
self._draw_pts_sem_seg(data_input['points'], self._draw_pts_sem_seg(data_input['points'],
data_sample.gt_pts_seg, palette, data_sample.gt_pts_seg, palette,
ignore_index) keep_index)
if draw_pred and data_sample is not None: if draw_pred and data_sample is not None:
if 'pred_instances_3d' in data_sample: if 'pred_instances_3d' in data_sample:
...@@ -922,7 +1024,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -922,7 +1024,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
assert 'points' in data_input assert 'points' in data_input
self._draw_pts_sem_seg(data_input['points'], self._draw_pts_sem_seg(data_input['points'],
data_sample.pred_pts_seg, palette, data_sample.pred_pts_seg, palette,
ignore_index) keep_index)
# monocular 3d object detection image # monocular 3d object detection image
if vis_task in ['mono_det', 'multi-modality_det']: if vis_task in ['mono_det', 'multi-modality_det']:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment