Unverified Commit e990105e authored by Sun Jiahao's avatar Sun Jiahao Committed by GitHub
Browse files

[Fix] Fix indoor det visualization (#2625)

* fix visual

* add AttributeError

* fix keep index bug

* fix open3d version bug
parent fb5a3232
...@@ -75,7 +75,8 @@ class Pack3DDetInputs(BaseTransform): ...@@ -75,7 +75,8 @@ class Pack3DDetInputs(BaseTransform):
'affine_aug', 'sweep_img_metas', 'ori_cam2img', 'affine_aug', 'sweep_img_metas', 'ori_cam2img',
'cam2global', 'crop_offset', 'img_crop_offset', 'cam2global', 'crop_offset', 'img_crop_offset',
'resize_img_shape', 'lidar2cam', 'ori_lidar2img', 'resize_img_shape', 'lidar2cam', 'ori_lidar2img',
'num_ref_frames', 'num_views', 'ego2global') 'num_ref_frames', 'num_views', 'ego2global',
'axis_align_matrix')
) -> None: ) -> None:
self.keys = keys self.keys = keys
self.meta_keys = meta_keys self.meta_keys = meta_keys
......
...@@ -43,6 +43,10 @@ class Det3DVisualizationHook(Hook): ...@@ -43,6 +43,10 @@ class Det3DVisualizationHook(Hook):
show (bool): Whether to display the drawn image. Default to False. show (bool): Whether to display the drawn image. Default to False.
vis_task (str): Visualization task. Defaults to 'mono_det'. vis_task (str): Visualization task. Defaults to 'mono_det'.
wait_time (float): The interval of show (s). Defaults to 0. wait_time (float): The interval of show (s). Defaults to 0.
draw_gt (bool): Whether to draw ground truth. Defaults to True.
draw_pred (bool): Whether to draw prediction. Defaults to True.
show_pcd_rgb (bool): Whether to show RGB point cloud. Defaults to
False.
test_out_dir (str, optional): directory where painted images test_out_dir (str, optional): directory where painted images
will be saved in testing process. will be saved in testing process.
backend_args (dict, optional): Arguments to instantiate the backend_args (dict, optional): Arguments to instantiate the
...@@ -57,8 +61,9 @@ class Det3DVisualizationHook(Hook): ...@@ -57,8 +61,9 @@ class Det3DVisualizationHook(Hook):
vis_task: str = 'mono_det', vis_task: str = 'mono_det',
wait_time: float = 0., wait_time: float = 0.,
test_out_dir: Optional[str] = None, test_out_dir: Optional[str] = None,
draw_gt: bool = True, draw_gt: bool = False,
draw_pred: bool = True, draw_pred: bool = True,
show_pcd_rgb: bool = False,
backend_args: Optional[dict] = None): backend_args: Optional[dict] = None):
self._visualizer: Visualizer = Visualizer.get_current_instance() self._visualizer: Visualizer = Visualizer.get_current_instance()
self.interval = interval self.interval = interval
...@@ -87,6 +92,7 @@ class Det3DVisualizationHook(Hook): ...@@ -87,6 +92,7 @@ class Det3DVisualizationHook(Hook):
self._test_index = 0 self._test_index = 0
self.draw_gt = draw_gt self.draw_gt = draw_gt
self.draw_pred = draw_pred self.draw_pred = draw_pred
self.show_pcd_rgb = show_pcd_rgb
def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict, def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
outputs: Sequence[Det3DDataSample]) -> None: outputs: Sequence[Det3DDataSample]) -> None:
...@@ -142,11 +148,14 @@ class Det3DVisualizationHook(Hook): ...@@ -142,11 +148,14 @@ class Det3DVisualizationHook(Hook):
'val sample', 'val sample',
data_input, data_input,
data_sample=outputs[0], data_sample=outputs[0],
draw_gt=self.draw_gt,
draw_pred=self.draw_pred,
show=self.show, show=self.show,
vis_task=self.vis_task, vis_task=self.vis_task,
wait_time=self.wait_time, wait_time=self.wait_time,
pred_score_thr=self.score_thr, pred_score_thr=self.score_thr,
step=total_curr_iter) step=total_curr_iter,
show_pcd_rgb=self.show_pcd_rgb)
def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict, def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
outputs: Sequence[Det3DDataSample]) -> None: outputs: Sequence[Det3DDataSample]) -> None:
...@@ -228,4 +237,5 @@ class Det3DVisualizationHook(Hook): ...@@ -228,4 +237,5 @@ class Det3DVisualizationHook(Hook):
pred_score_thr=self.score_thr, pred_score_thr=self.score_thr,
out_file=out_file, out_file=out_file,
o3d_save_path=o3d_save_path, o3d_save_path=o3d_save_path,
step=self._test_index) step=self._test_index,
show_pcd_rgb=self.show_pcd_rgb)
...@@ -23,9 +23,9 @@ from torch import Tensor ...@@ -23,9 +23,9 @@ from torch import Tensor
from mmdet3d.registry import VISUALIZERS from mmdet3d.registry import VISUALIZERS
from mmdet3d.structures import (BaseInstance3DBoxes, Box3DMode, from mmdet3d.structures import (BaseInstance3DBoxes, Box3DMode,
CameraInstance3DBoxes, Coord3DMode, CameraInstance3DBoxes, Coord3DMode,
DepthInstance3DBoxes, Det3DDataSample, DepthInstance3DBoxes, DepthPoints,
LiDARInstance3DBoxes, PointData, Det3DDataSample, LiDARInstance3DBoxes,
points_cam2img) PointData, points_cam2img)
from .vis_utils import (proj_camera_bbox3d_to_img, proj_depth_bbox3d_to_img, from .vis_utils import (proj_camera_bbox3d_to_img, proj_depth_bbox3d_to_img,
proj_lidar_bbox3d_to_img, to_depth_mode) proj_lidar_bbox3d_to_img, to_depth_mode)
...@@ -293,7 +293,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -293,7 +293,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
# convert bboxes to numpy dtype # convert bboxes to numpy dtype
bboxes_3d = tensor2ndarray(bboxes_3d.tensor) bboxes_3d = tensor2ndarray(bboxes_3d.tensor)
in_box_color = np.array(points_in_box_color) # in_box_color = np.array(points_in_box_color)
for i in range(len(bboxes_3d)): for i in range(len(bboxes_3d)):
center = bboxes_3d[i, 0:3] center = bboxes_3d[i, 0:3]
...@@ -320,7 +320,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -320,7 +320,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
if self.pcd is not None and mode == 'xyz': if self.pcd is not None and mode == 'xyz':
indices = box3d.get_point_indices_within_bounding_box( indices = box3d.get_point_indices_within_bounding_box(
self.pcd.points) self.pcd.points)
self.points_colors[indices] = in_box_color self.points_colors[indices] = np.array(bbox_color[i]) / 255.
# update points colors # update points colors
if self.pcd is not None: if self.pcd is not None:
...@@ -606,6 +606,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -606,6 +606,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
instances: InstanceData, instances: InstanceData,
input_meta: dict, input_meta: dict,
vis_task: str, vis_task: str,
show_pcd_rgb: bool = False,
palette: Optional[List[tuple]] = None) -> dict: palette: Optional[List[tuple]] = None) -> dict:
"""Draw 3D instances of GT or prediction. """Draw 3D instances of GT or prediction.
...@@ -616,6 +617,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -616,6 +617,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
input_meta (dict): Meta information. input_meta (dict): Meta information.
vis_task (str): Visualization task, it includes: 'lidar_det', vis_task (str): Visualization task, it includes: 'lidar_det',
'multi-modality_det', 'mono_det'. 'multi-modality_det', 'mono_det'.
show_pcd_rgb (bool): Whether to show RGB point cloud.
palette (List[tuple], optional): Palette information corresponding palette (List[tuple], optional): Palette information corresponding
to the category. Defaults to None. to the category. Defaults to None.
...@@ -643,13 +645,22 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -643,13 +645,22 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
else: else:
bboxes_3d_depth = bboxes_3d.clone() bboxes_3d_depth = bboxes_3d.clone()
if 'axis_align_matrix' in input_meta:
points = DepthPoints(points, points_dim=points.shape[1])
rot_mat = input_meta['axis_align_matrix'][:3, :3]
trans_vec = input_meta['axis_align_matrix'][:3, -1]
points.rotate(rot_mat.T)
points.translate(trans_vec)
points = tensor2ndarray(points.tensor)
max_label = int(max(labels_3d) if len(labels_3d) > 0 else 0) max_label = int(max(labels_3d) if len(labels_3d) > 0 else 0)
bbox_color = palette if self.bbox_color is None \ bbox_color = palette if self.bbox_color is None \
else self.bbox_color else self.bbox_color
bbox_palette = get_palette(bbox_color, max_label + 1) bbox_palette = get_palette(bbox_color, max_label + 1)
colors = [bbox_palette[label] for label in labels_3d] colors = [bbox_palette[label] for label in labels_3d]
self.set_points(points, pcd_mode=2) self.set_points(
points, pcd_mode=2, mode='xyzrgb' if show_pcd_rgb else 'xyz')
self.draw_bboxes_3d(bboxes_3d_depth, bbox_color=colors) self.draw_bboxes_3d(bboxes_3d_depth, bbox_color=colors)
data_3d['bboxes_3d'] = tensor2ndarray(bboxes_3d_depth.tensor) data_3d['bboxes_3d'] = tensor2ndarray(bboxes_3d_depth.tensor)
...@@ -871,7 +882,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -871,7 +882,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
self.o3d_vis.clear_geometries() self.o3d_vis.clear_geometries()
try: try:
del self.pcd del self.pcd
except KeyError: except (KeyError, AttributeError):
pass pass
if save_path is not None: if save_path is not None:
if not (save_path.endswith('.png') if not (save_path.endswith('.png')
...@@ -923,7 +934,8 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -923,7 +934,8 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
o3d_save_path: Optional[str] = None, o3d_save_path: Optional[str] = None,
vis_task: str = 'mono_det', vis_task: str = 'mono_det',
pred_score_thr: float = 0.3, pred_score_thr: float = 0.3,
step: int = 0) -> None: step: int = 0,
show_pcd_rgb: bool = False) -> None:
"""Draw datasample and save to all backends. """Draw datasample and save to all backends.
- If GT and prediction are plotted at the same time, they are displayed - If GT and prediction are plotted at the same time, they are displayed
...@@ -954,6 +966,8 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -954,6 +966,8 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
pred_score_thr (float): The threshold to visualize the bboxes pred_score_thr (float): The threshold to visualize the bboxes
and masks. Defaults to 0.3. and masks. Defaults to 0.3.
step (int): Global step value to record. Defaults to 0. step (int): Global step value to record. Defaults to 0.
show_pcd_rgb (bool): Whether to show RGB point cloud. Defaults to
False.
""" """
assert vis_task in ( assert vis_task in (
'mono_det', 'multi-view_det', 'lidar_det', 'lidar_seg', 'mono_det', 'multi-view_det', 'lidar_det', 'lidar_seg',
...@@ -976,7 +990,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -976,7 +990,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
if 'gt_instances_3d' in data_sample: if 'gt_instances_3d' in data_sample:
gt_data_3d = self._draw_instances_3d( gt_data_3d = self._draw_instances_3d(
data_input, data_sample.gt_instances_3d, data_input, data_sample.gt_instances_3d,
data_sample.metainfo, vis_task, palette) data_sample.metainfo, vis_task, show_pcd_rgb, palette)
if 'gt_instances' in data_sample: if 'gt_instances' in data_sample:
if len(data_sample.gt_instances) > 0: if len(data_sample.gt_instances) > 0:
assert 'img' in data_input assert 'img' in data_input
...@@ -1006,7 +1020,8 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -1006,7 +1020,8 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
pred_data_3d = self._draw_instances_3d(data_input, pred_data_3d = self._draw_instances_3d(data_input,
pred_instances_3d, pred_instances_3d,
data_sample.metainfo, data_sample.metainfo,
vis_task, palette) vis_task, show_pcd_rgb,
palette)
if 'pred_instances' in data_sample: if 'pred_instances' in data_sample:
if 'img' in data_input and len(data_sample.pred_instances) > 0: if 'img' in data_input and len(data_sample.pred_instances) > 0:
pred_instances = data_sample.pred_instances pred_instances = data_sample.pred_instances
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment