Unverified Commit 997b026b authored by Wenhao Wu's avatar Wenhao Wu Committed by GitHub
Browse files

[Fix] enable visualization of the demo results online (#473)

parent e9d84fe5
...@@ -17,6 +17,12 @@ def main(): ...@@ -17,6 +17,12 @@ def main():
'--score-thr', type=float, default=0.0, help='bbox score threshold') '--score-thr', type=float, default=0.0, help='bbox score threshold')
parser.add_argument( parser.add_argument(
'--out-dir', type=str, default='demo', help='dir to save results') '--out-dir', type=str, default='demo', help='dir to save results')
parser.add_argument(
'--show', action='store_true', help='show online visuliaztion results')
parser.add_argument(
'--snapshot',
action='store_true',
help='whether to save online visuliaztion results')
args = parser.parse_args() args = parser.parse_args()
# build the model from a config file and a checkpoint file # build the model from a config file and a checkpoint file
...@@ -25,7 +31,13 @@ def main(): ...@@ -25,7 +31,13 @@ def main():
result, data = inference_multi_modality_detector(model, args.pcd, result, data = inference_multi_modality_detector(model, args.pcd,
args.image, args.ann) args.image, args.ann)
# show the results # show the results
show_result_meshlab(data, result, args.out_dir, args.score_thr) show_result_meshlab(
data,
result,
args.out_dir,
args.score_thr,
show=args.show,
snapshot=args.snapshot)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -14,6 +14,12 @@ def main(): ...@@ -14,6 +14,12 @@ def main():
'--score-thr', type=float, default=0.0, help='bbox score threshold') '--score-thr', type=float, default=0.0, help='bbox score threshold')
parser.add_argument( parser.add_argument(
'--out-dir', type=str, default='demo', help='dir to save results') '--out-dir', type=str, default='demo', help='dir to save results')
parser.add_argument(
'--show', action='store_true', help='show online visuliaztion results')
parser.add_argument(
'--snapshot',
action='store_true',
help='whether to save online visuliaztion results')
args = parser.parse_args() args = parser.parse_args()
# build the model from a config file and a checkpoint file # build the model from a config file and a checkpoint file
...@@ -21,7 +27,13 @@ def main(): ...@@ -21,7 +27,13 @@ def main():
# test a single image # test a single image
result, data = inference_detector(model, args.pcd) result, data = inference_detector(model, args.pcd)
# show the results # show the results
show_result_meshlab(data, result, args.out_dir, args.score_thr) show_result_meshlab(
data,
result,
args.out_dir,
args.score_thr,
show=args.show,
snapshot=args.snapshot)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -185,7 +185,12 @@ def inference_multi_modality_detector(model, pcd, image, ann_file): ...@@ -185,7 +185,12 @@ def inference_multi_modality_detector(model, pcd, image, ann_file):
return result, data return result, data
def show_result_meshlab(data, result, out_dir, score_thr=0.0): def show_result_meshlab(data,
result,
out_dir,
score_thr=0.0,
show=False,
snapshot=False):
"""Show result by meshlab. """Show result by meshlab.
Args: Args:
...@@ -193,6 +198,8 @@ def show_result_meshlab(data, result, out_dir, score_thr=0.0): ...@@ -193,6 +198,8 @@ def show_result_meshlab(data, result, out_dir, score_thr=0.0):
result (dict): Predicted result from model. result (dict): Predicted result from model.
out_dir (str): Directory to save visualized result. out_dir (str): Directory to save visualized result.
score_thr (float): Minimum score of bboxes to be shown. Default: 0.0 score_thr (float): Minimum score of bboxes to be shown. Default: 0.0
show (bool): Visualize the results online. Defaults to False.
snapshot (bool): Whether to save the online results. Defaults to False.
""" """
points = data['points'][0][0].cpu().numpy() points = data['points'][0][0].cpu().numpy()
pts_filename = data['img_metas'][0][0]['pts_filename'] pts_filename = data['img_metas'][0][0]['pts_filename']
...@@ -220,7 +227,14 @@ def show_result_meshlab(data, result, out_dir, score_thr=0.0): ...@@ -220,7 +227,14 @@ def show_result_meshlab(data, result, out_dir, score_thr=0.0):
show_bboxes = Box3DMode.convert(pred_bboxes, box_mode, Box3DMode.DEPTH) show_bboxes = Box3DMode.convert(pred_bboxes, box_mode, Box3DMode.DEPTH)
else: else:
show_bboxes = deepcopy(pred_bboxes) show_bboxes = deepcopy(pred_bboxes)
show_result(points, None, show_bboxes, out_dir, file_name, show=False) show_result(
points,
None,
show_bboxes,
out_dir,
file_name,
show=show,
snapshot=snapshot)
if 'img' not in data.keys(): if 'img' not in data.keys():
return out_dir, file_name return out_dir, file_name
...@@ -242,7 +256,7 @@ def show_result_meshlab(data, result, out_dir, score_thr=0.0): ...@@ -242,7 +256,7 @@ def show_result_meshlab(data, result, out_dir, score_thr=0.0):
data['img_metas'][0][0]['lidar2img'], data['img_metas'][0][0]['lidar2img'],
out_dir, out_dir,
file_name, file_name,
show=False) show=show)
elif box_mode == Box3DMode.DEPTH: elif box_mode == Box3DMode.DEPTH:
if 'calib' not in data.keys(): if 'calib' not in data.keys():
raise NotImplementedError( raise NotImplementedError(
...@@ -260,7 +274,7 @@ def show_result_meshlab(data, result, out_dir, score_thr=0.0): ...@@ -260,7 +274,7 @@ def show_result_meshlab(data, result, out_dir, score_thr=0.0):
file_name, file_name,
depth_bbox=True, depth_bbox=True,
img_metas=data['img_metas'][0][0], img_metas=data['img_metas'][0][0],
show=False) show=show)
else: else:
raise NotImplementedError( raise NotImplementedError(
f'visualization of {box_mode} bbox is not supported') f'visualization of {box_mode} bbox is not supported')
......
...@@ -70,7 +70,13 @@ def _write_oriented_bbox(scene_bbox, out_filename): ...@@ -70,7 +70,13 @@ def _write_oriented_bbox(scene_bbox, out_filename):
return return
def show_result(points, gt_bboxes, pred_bboxes, out_dir, filename, show=True): def show_result(points,
gt_bboxes,
pred_bboxes,
out_dir,
filename,
show=False,
snapshot=False):
"""Convert results into format that is directly readable for meshlab. """Convert results into format that is directly readable for meshlab.
Args: Args:
...@@ -79,8 +85,12 @@ def show_result(points, gt_bboxes, pred_bboxes, out_dir, filename, show=True): ...@@ -79,8 +85,12 @@ def show_result(points, gt_bboxes, pred_bboxes, out_dir, filename, show=True):
pred_bboxes (np.ndarray): Predicted boxes. pred_bboxes (np.ndarray): Predicted boxes.
out_dir (str): Path of output directory out_dir (str): Path of output directory
filename (str): Filename of the current frame. filename (str): Filename of the current frame.
show (bool): Visualize the results online. Defaults to True. show (bool): Visualize the results online. Defaults to False.
snapshot (bool): Whether to save the online results. Defaults to False.
""" """
result_path = osp.join(out_dir, filename)
mmcv.mkdir_or_exist(result_path)
if show: if show:
from .open3d_vis import Visualizer from .open3d_vis import Visualizer
...@@ -89,10 +99,9 @@ def show_result(points, gt_bboxes, pred_bboxes, out_dir, filename, show=True): ...@@ -89,10 +99,9 @@ def show_result(points, gt_bboxes, pred_bboxes, out_dir, filename, show=True):
vis.add_bboxes(bbox3d=pred_bboxes) vis.add_bboxes(bbox3d=pred_bboxes)
if gt_bboxes is not None: if gt_bboxes is not None:
vis.add_bboxes(bbox3d=gt_bboxes, bbox_color=(0, 0, 1)) vis.add_bboxes(bbox3d=gt_bboxes, bbox_color=(0, 0, 1))
vis.show() show_path = osp.join(result_path,
f'{filename}_online.png') if snapshot else None
result_path = osp.join(out_dir, filename) vis.show(show_path)
mmcv.mkdir_or_exist(result_path)
if points is not None: if points is not None:
_write_obj(points, osp.join(result_path, f'{filename}_points.obj')) _write_obj(points, osp.join(result_path, f'{filename}_points.obj'))
...@@ -121,7 +130,8 @@ def show_seg_result(points, ...@@ -121,7 +130,8 @@ def show_seg_result(points,
filename, filename,
palette, palette,
ignore_index=None, ignore_index=None,
show=False): show=False,
snapshot=False):
"""Convert results into format that is directly readable for meshlab. """Convert results into format that is directly readable for meshlab.
Args: Args:
...@@ -134,6 +144,8 @@ def show_seg_result(points, ...@@ -134,6 +144,8 @@ def show_seg_result(points,
ignore_index (int, optional): The label index to be ignored, e.g. \ ignore_index (int, optional): The label index to be ignored, e.g. \
unannotated points. Defaults to None. unannotated points. Defaults to None.
show (bool, optional): Visualize the results online. Defaults to False. show (bool, optional): Visualize the results online. Defaults to False.
snapshot (bool, optional): Whether to save the online results. \
Defaults to False.
""" """
# we need 3D coordinates to visualize segmentation mask # we need 3D coordinates to visualize segmentation mask
if gt_seg is not None or pred_seg is not None: if gt_seg is not None or pred_seg is not None:
...@@ -156,6 +168,9 @@ def show_seg_result(points, ...@@ -156,6 +168,9 @@ def show_seg_result(points,
pred_seg_color = np.concatenate([points[:, :3], pred_seg_color], pred_seg_color = np.concatenate([points[:, :3], pred_seg_color],
axis=1) axis=1)
result_path = osp.join(out_dir, filename)
mmcv.mkdir_or_exist(result_path)
# online visualization of segmentation mask # online visualization of segmentation mask
# we show three masks in a row, scene_points, gt_mask, pred_mask # we show three masks in a row, scene_points, gt_mask, pred_mask
if show: if show:
...@@ -166,10 +181,9 @@ def show_seg_result(points, ...@@ -166,10 +181,9 @@ def show_seg_result(points,
vis.add_seg_mask(gt_seg_color) vis.add_seg_mask(gt_seg_color)
if pred_seg is not None: if pred_seg is not None:
vis.add_seg_mask(pred_seg_color) vis.add_seg_mask(pred_seg_color)
vis.show() show_path = osp.join(result_path,
f'{filename}_online.png') if snapshot else None
result_path = osp.join(out_dir, filename) vis.show(show_path)
mmcv.mkdir_or_exist(result_path)
if points is not None: if points is not None:
_write_obj(points, osp.join(result_path, f'{filename}_points.obj')) _write_obj(points, osp.join(result_path, f'{filename}_points.obj'))
...@@ -190,7 +204,7 @@ def show_multi_modality_result(img, ...@@ -190,7 +204,7 @@ def show_multi_modality_result(img,
filename, filename,
depth_bbox=False, depth_bbox=False,
img_metas=None, img_metas=None,
show=True, show=False,
gt_bbox_color=(61, 102, 255), gt_bbox_color=(61, 102, 255),
pred_bbox_color=(241, 101, 72)): pred_bbox_color=(241, 101, 72)):
"""Convert multi-modality detection results into 2D results. """Convert multi-modality detection results into 2D results.
...@@ -207,7 +221,7 @@ def show_multi_modality_result(img, ...@@ -207,7 +221,7 @@ def show_multi_modality_result(img,
filename (str): Filename of the current frame. filename (str): Filename of the current frame.
depth_bbox (bool): Whether we are projecting camera bbox or lidar bbox. depth_bbox (bool): Whether we are projecting camera bbox or lidar bbox.
img_metas (dict): Used in projecting cameta bbox. img_metas (dict): Used in projecting cameta bbox.
show (bool): Visualize the results online. Defaults to True. show (bool): Visualize the results online. Defaults to False.
gt_bbox_color (str or tuple(int)): Color of bbox lines. gt_bbox_color (str or tuple(int)): Color of bbox lines.
The tuple of color should be in BGR order. Default: (255, 102, 61) The tuple of color should be in BGR order. Default: (255, 102, 61)
pred_bbox_color (str or tuple(int)): Color of bbox lines. pred_bbox_color (str or tuple(int)): Color of bbox lines.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment