Unverified Commit e990105e authored by Sun Jiahao's avatar Sun Jiahao Committed by GitHub
Browse files

[Fix] Fix indoor det visualization (#2625)

* fix visual

* add AttributeError

* fix keep index bug

* fix open3d version bug
parent fb5a3232
......@@ -75,7 +75,8 @@ class Pack3DDetInputs(BaseTransform):
'affine_aug', 'sweep_img_metas', 'ori_cam2img',
'cam2global', 'crop_offset', 'img_crop_offset',
'resize_img_shape', 'lidar2cam', 'ori_lidar2img',
'num_ref_frames', 'num_views', 'ego2global')
'num_ref_frames', 'num_views', 'ego2global',
'axis_align_matrix')
) -> None:
self.keys = keys
self.meta_keys = meta_keys
......
......@@ -43,6 +43,10 @@ class Det3DVisualizationHook(Hook):
show (bool): Whether to display the drawn image. Default to False.
vis_task (str): Visualization task. Defaults to 'mono_det'.
wait_time (float): The interval of show (s). Defaults to 0.
draw_gt (bool): Whether to draw ground truth. Defaults to True.
draw_pred (bool): Whether to draw prediction. Defaults to True.
show_pcd_rgb (bool): Whether to show RGB point cloud. Defaults to
False.
test_out_dir (str, optional): directory where painted images
will be saved in testing process.
backend_args (dict, optional): Arguments to instantiate the
......@@ -57,8 +61,9 @@ class Det3DVisualizationHook(Hook):
vis_task: str = 'mono_det',
wait_time: float = 0.,
test_out_dir: Optional[str] = None,
draw_gt: bool = True,
draw_gt: bool = False,
draw_pred: bool = True,
show_pcd_rgb: bool = False,
backend_args: Optional[dict] = None):
self._visualizer: Visualizer = Visualizer.get_current_instance()
self.interval = interval
......@@ -87,6 +92,7 @@ class Det3DVisualizationHook(Hook):
self._test_index = 0
self.draw_gt = draw_gt
self.draw_pred = draw_pred
self.show_pcd_rgb = show_pcd_rgb
def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
outputs: Sequence[Det3DDataSample]) -> None:
......@@ -142,11 +148,14 @@ class Det3DVisualizationHook(Hook):
'val sample',
data_input,
data_sample=outputs[0],
draw_gt=self.draw_gt,
draw_pred=self.draw_pred,
show=self.show,
vis_task=self.vis_task,
wait_time=self.wait_time,
pred_score_thr=self.score_thr,
step=total_curr_iter)
step=total_curr_iter,
show_pcd_rgb=self.show_pcd_rgb)
def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
outputs: Sequence[Det3DDataSample]) -> None:
......@@ -228,4 +237,5 @@ class Det3DVisualizationHook(Hook):
pred_score_thr=self.score_thr,
out_file=out_file,
o3d_save_path=o3d_save_path,
step=self._test_index)
step=self._test_index,
show_pcd_rgb=self.show_pcd_rgb)
......@@ -23,9 +23,9 @@ from torch import Tensor
from mmdet3d.registry import VISUALIZERS
from mmdet3d.structures import (BaseInstance3DBoxes, Box3DMode,
CameraInstance3DBoxes, Coord3DMode,
DepthInstance3DBoxes, Det3DDataSample,
LiDARInstance3DBoxes, PointData,
points_cam2img)
DepthInstance3DBoxes, DepthPoints,
Det3DDataSample, LiDARInstance3DBoxes,
PointData, points_cam2img)
from .vis_utils import (proj_camera_bbox3d_to_img, proj_depth_bbox3d_to_img,
proj_lidar_bbox3d_to_img, to_depth_mode)
......@@ -293,7 +293,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
# convert bboxes to numpy dtype
bboxes_3d = tensor2ndarray(bboxes_3d.tensor)
in_box_color = np.array(points_in_box_color)
# in_box_color = np.array(points_in_box_color)
for i in range(len(bboxes_3d)):
center = bboxes_3d[i, 0:3]
......@@ -320,7 +320,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
if self.pcd is not None and mode == 'xyz':
indices = box3d.get_point_indices_within_bounding_box(
self.pcd.points)
self.points_colors[indices] = in_box_color
self.points_colors[indices] = np.array(bbox_color[i]) / 255.
# update points colors
if self.pcd is not None:
......@@ -606,6 +606,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
instances: InstanceData,
input_meta: dict,
vis_task: str,
show_pcd_rgb: bool = False,
palette: Optional[List[tuple]] = None) -> dict:
"""Draw 3D instances of GT or prediction.
......@@ -616,6 +617,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
input_meta (dict): Meta information.
vis_task (str): Visualization task, it includes: 'lidar_det',
'multi-modality_det', 'mono_det'.
show_pcd_rgb (bool): Whether to show RGB point cloud.
palette (List[tuple], optional): Palette information corresponding
to the category. Defaults to None.
......@@ -643,13 +645,22 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
else:
bboxes_3d_depth = bboxes_3d.clone()
if 'axis_align_matrix' in input_meta:
points = DepthPoints(points, points_dim=points.shape[1])
rot_mat = input_meta['axis_align_matrix'][:3, :3]
trans_vec = input_meta['axis_align_matrix'][:3, -1]
points.rotate(rot_mat.T)
points.translate(trans_vec)
points = tensor2ndarray(points.tensor)
max_label = int(max(labels_3d) if len(labels_3d) > 0 else 0)
bbox_color = palette if self.bbox_color is None \
else self.bbox_color
bbox_palette = get_palette(bbox_color, max_label + 1)
colors = [bbox_palette[label] for label in labels_3d]
self.set_points(points, pcd_mode=2)
self.set_points(
points, pcd_mode=2, mode='xyzrgb' if show_pcd_rgb else 'xyz')
self.draw_bboxes_3d(bboxes_3d_depth, bbox_color=colors)
data_3d['bboxes_3d'] = tensor2ndarray(bboxes_3d_depth.tensor)
......@@ -871,7 +882,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
self.o3d_vis.clear_geometries()
try:
del self.pcd
except KeyError:
except (KeyError, AttributeError):
pass
if save_path is not None:
if not (save_path.endswith('.png')
......@@ -923,7 +934,8 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
o3d_save_path: Optional[str] = None,
vis_task: str = 'mono_det',
pred_score_thr: float = 0.3,
step: int = 0) -> None:
step: int = 0,
show_pcd_rgb: bool = False) -> None:
"""Draw datasample and save to all backends.
- If GT and prediction are plotted at the same time, they are displayed
......@@ -954,6 +966,8 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
pred_score_thr (float): The threshold to visualize the bboxes
and masks. Defaults to 0.3.
step (int): Global step value to record. Defaults to 0.
show_pcd_rgb (bool): Whether to show RGB point cloud. Defaults to
False.
"""
assert vis_task in (
'mono_det', 'multi-view_det', 'lidar_det', 'lidar_seg',
......@@ -976,7 +990,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
if 'gt_instances_3d' in data_sample:
gt_data_3d = self._draw_instances_3d(
data_input, data_sample.gt_instances_3d,
data_sample.metainfo, vis_task, palette)
data_sample.metainfo, vis_task, show_pcd_rgb, palette)
if 'gt_instances' in data_sample:
if len(data_sample.gt_instances) > 0:
assert 'img' in data_input
......@@ -1006,7 +1020,8 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
pred_data_3d = self._draw_instances_3d(data_input,
pred_instances_3d,
data_sample.metainfo,
vis_task, palette)
vis_task, show_pcd_rgb,
palette)
if 'pred_instances' in data_sample:
if 'img' in data_input and len(data_sample.pred_instances) > 0:
pred_instances = data_sample.pred_instances
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment