Commit 3960f9a7 authored by ChaimZhu's avatar ChaimZhu Committed by ZwwWayne
Browse files

[Fix] fix point cloud loop visualization error (#1914)



* fix point cloud loop visualization error

* fix browse_dataset

* fix browse_dataset

* support saving lidar_det
Co-authored-by: default avatarJingweiZhang12 <zjw18@mails.tsinghua.edu.cn>
parent 7a220a9a
...@@ -120,7 +120,15 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -120,7 +120,15 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
self.set_points(points, pcd_mode=pcd_mode, frame_cfg=frame_cfg) self.set_points(points, pcd_mode=pcd_mode, frame_cfg=frame_cfg)
self.pts_seg_num = 0 self.pts_seg_num = 0
def _initialize_o3d_vis(self, frame_cfg) -> tuple: def _clear_o3d_vis(self) -> None:
"""Clear open3d vis."""
if hasattr(self, 'o3d_vis'):
del self.o3d_vis
del self.pcd
del self.points_colors
def _initialize_o3d_vis(self, frame_cfg) -> o3d.visualization.Visualizer:
"""Initialize open3d vis according to frame_cfg. """Initialize open3d vis according to frame_cfg.
Args: Args:
...@@ -141,8 +149,8 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -141,8 +149,8 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
def set_points(self, def set_points(self,
points: np.ndarray, points: np.ndarray,
pcd_mode: int = 0, pcd_mode: int = 0,
vis_mode: str = 'replace',
frame_cfg: dict = dict(size=1, origin=[0, 0, 0]), frame_cfg: dict = dict(size=1, origin=[0, 0, 0]),
vis_task: str = 'lidar_det',
points_color: Tuple = (0.5, 0.5, 0.5), points_color: Tuple = (0.5, 0.5, 0.5),
points_size: int = 2, points_size: int = 2,
mode: str = 'xyz') -> None: mode: str = 'xyz') -> None:
...@@ -154,11 +162,15 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -154,11 +162,15 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
pcd_mode (int): The point cloud mode (coordinates): pcd_mode (int): The point cloud mode (coordinates):
0 represents LiDAR, 1 represents CAMERA, 2 0 represents LiDAR, 1 represents CAMERA, 2
represents Depth. Defaults to 0. represents Depth. Defaults to 0.
vis_mode (str): The visualization mode in Open3D:
'replace': Replace the existing point cloud with
input point cloud.
'add': Add input point cloud into existing point
cloud.
Defaults to 'replace'.
frame_cfg (dict): The coordinate frame config while Open3D frame_cfg (dict): The coordinate frame config while Open3D
visualization initialization. visualization initialization.
Defaults to dict(size=1, origin=[0, 0, 0]). Defaults to dict(size=1, origin=[0, 0, 0]).
vis_task (str): Visualiztion task, it includes:
'lidar_det', 'multi-modality_det', 'mono_det', 'lidar_seg'.
point_color (tuple[float], optional): the color of points. point_color (tuple[float], optional): the color of points.
Default: (0.5, 0.5, 0.5). Default: (0.5, 0.5, 0.5).
points_size (int, optional): the size of points to show points_size (int, optional): the size of points to show
...@@ -167,6 +179,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -167,6 +179,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
available mode ['xyz', 'xyzrgb']. Default: 'xyz'. available mode ['xyz', 'xyzrgb']. Default: 'xyz'.
""" """
assert points is not None assert points is not None
assert vis_mode in ('replace', 'add')
check_type('points', points, np.ndarray) check_type('points', points, np.ndarray)
if not hasattr(self, 'o3d_vis'): if not hasattr(self, 'o3d_vis'):
...@@ -176,7 +189,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -176,7 +189,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
if pcd_mode != Coord3DMode.DEPTH: if pcd_mode != Coord3DMode.DEPTH:
points = Coord3DMode.convert(points, pcd_mode, Coord3DMode.DEPTH) points = Coord3DMode.convert(points, pcd_mode, Coord3DMode.DEPTH)
if hasattr(self, 'pcd') and vis_task != 'lidar_seg': if hasattr(self, 'pcd') and vis_mode != 'add':
self.o3d_vis.remove_geometry(self.pcd) self.o3d_vis.remove_geometry(self.pcd)
# set points size in Open3D # set points size in Open3D
...@@ -524,8 +537,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -524,8 +537,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
self.o3d_vis.add_geometry(mesh_frame) self.o3d_vis.add_geometry(mesh_frame)
seg_points = copy.deepcopy(seg_mask_colors) seg_points = copy.deepcopy(seg_mask_colors)
seg_points[:, 0] += offset seg_points[:, 0] += offset
self.set_points( self.set_points(seg_points, pcd_mode=2, vis_mode='add', mode='xyzrgb')
seg_points, vis_task='lidar_seg', pcd_mode=2, mode='xyzrgb')
def _draw_instances_3d(self, data_input: dict, instances: InstanceData, def _draw_instances_3d(self, data_input: dict, instances: InstanceData,
input_meta: dict, vis_task: str, input_meta: dict, vis_task: str,
...@@ -559,7 +571,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -559,7 +571,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
else: else:
bboxes_3d_depth = bboxes_3d.clone() bboxes_3d_depth = bboxes_3d.clone()
self.set_points(points, pcd_mode=2, vis_task=vis_task) self.set_points(points, pcd_mode=2)
self.draw_bboxes_3d(bboxes_3d_depth) self.draw_bboxes_3d(bboxes_3d_depth)
data_3d['bboxes_3d'] = tensor2ndarray(bboxes_3d_depth.tensor) data_3d['bboxes_3d'] = tensor2ndarray(bboxes_3d_depth.tensor)
...@@ -614,7 +626,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -614,7 +626,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
pts_color = palette[pts_sem_seg] pts_color = palette[pts_sem_seg]
seg_color = np.concatenate([points[:, :3], pts_color], axis=1) seg_color = np.concatenate([points[:, :3], pts_color], axis=1)
self.set_points(points, pcd_mode=2, vis_task='lidar_seg') self.set_points(points, pcd_mode=2, vis_mode='add')
self.draw_seg_mask(seg_color) self.draw_seg_mask(seg_color)
@master_only @master_only
...@@ -644,6 +656,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -644,6 +656,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
if save_path is not None: if save_path is not None:
self.o3d_vis.capture_screen_image(save_path) self.o3d_vis.capture_screen_image(save_path)
self.o3d_vis.destroy_window() self.o3d_vis.destroy_window()
self._clear_o3d_vis()
if hasattr(self, '_image'): if hasattr(self, '_image'):
if drawn_img_3d is None: if drawn_img_3d is None:
...@@ -662,7 +675,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -662,7 +675,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
show: bool = False, show: bool = False,
wait_time: float = 0, wait_time: float = 0,
out_file: Optional[str] = None, out_file: Optional[str] = None,
save_path: Optional[str] = None, o3d_save_path: Optional[str] = None,
vis_task: str = 'mono_det', vis_task: str = 'mono_det',
pred_score_thr: float = 0.3, pred_score_thr: float = 0.3,
step: int = 0) -> None: step: int = 0) -> None:
...@@ -673,9 +686,8 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -673,9 +686,8 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
ground truth and the right image is the prediction. ground truth and the right image is the prediction.
- If ``show`` is True, all storage backends are ignored, and - If ``show`` is True, all storage backends are ignored, and
the images will be displayed in a local window. the images will be displayed in a local window.
- If ``out_file`` is specified, the drawn point cloud or - If ``out_file`` is specified, the drawn image will be saved to
image will be saved to ``out_file``. t is usually used when ``out_file``. It is usually used when the display is not available.
the display is not available.
Args: Args:
name (str): The image identifier. name (str): The image identifier.
...@@ -691,8 +703,8 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -691,8 +703,8 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
image. Default to False. image. Default to False.
wait_time (float): The interval of show (s). Defaults to 0. wait_time (float): The interval of show (s). Defaults to 0.
out_file (str): Path to output file. Defaults to None. out_file (str): Path to output file. Defaults to None.
save_path (str, optional): Path to save open3d visualized results. o3d_save_path (str, optional): Path to save open3d visualized
Default: None. results Default: None.
vis-task (str): Visualization task. Defaults to 'mono_det'. vis-task (str): Visualization task. Defaults to 'mono_det'.
pred_score_thr (float): The threshold to visualize the bboxes pred_score_thr (float): The threshold to visualize the bboxes
and masks. Defaults to 0.3. and masks. Defaults to 0.3.
...@@ -786,8 +798,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -786,8 +798,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
if show: if show:
self.show( self.show(
vis_task, o3d_save_path,
save_path,
drawn_img_3d, drawn_img_3d,
drawn_img, drawn_img,
win_name=name, win_name=name,
......
...@@ -71,7 +71,7 @@ def build_data_cfg(config_path, aug, cfg_options): ...@@ -71,7 +71,7 @@ def build_data_cfg(config_path, aug, cfg_options):
if aug: if aug:
show_pipeline = cfg.train_pipeline show_pipeline = cfg.train_pipeline
else: else:
show_pipeline = cfg.eval_pipeline show_pipeline = cfg.test_pipeline
for i in range(len(cfg.train_pipeline)): for i in range(len(cfg.train_pipeline)):
if cfg.train_pipeline[i]['type'] == 'LoadAnnotations3D': if cfg.train_pipeline[i]['type'] == 'LoadAnnotations3D':
show_pipeline.insert(i, cfg.train_pipeline[i]) show_pipeline.insert(i, cfg.train_pipeline[i])
...@@ -117,13 +117,20 @@ def main(): ...@@ -117,13 +117,20 @@ def main():
progress_bar = ProgressBar(len(dataset)) progress_bar = ProgressBar(len(dataset))
for item in dataset: for i, item in enumerate(dataset):
# the 3D Boxes in input could be in any of three coordinates # the 3D Boxes in input could be in any of three coordinates
data_input = item['inputs'] data_input = item['inputs']
data_sample = item['data_samples'].numpy() data_sample = item['data_samples'].numpy()
out_file = osp.join( out_file = osp.join(
args.output_dir) if args.output_dir is not None else None args.output_dir,
f'{i}.jpg') if args.output_dir is not None else None
# o3d_save_path is valid when args.not_show is False
o3d_save_path = osp.join(args.output_dir, f'pc_{i}.png') if (
args.output_dir is not None
and vis_task in ['lidar_det', 'lidar_seg', 'multi-modality_det']
and not args.not_show) else None
visualizer.add_datasample( visualizer.add_datasample(
'3d visualzier', '3d visualzier',
...@@ -132,6 +139,7 @@ def main(): ...@@ -132,6 +139,7 @@ def main():
show=not args.not_show, show=not args.not_show,
wait_time=args.show_interval, wait_time=args.show_interval,
out_file=out_file, out_file=out_file,
o3d_save_path=o3d_save_path,
vis_task=vis_task) vis_task=vis_task)
progress_bar.update() progress_bar.update()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment