Unverified Commit f84a17b8 authored by MilkClouds's avatar MilkClouds Committed by GitHub
Browse files

Label visualization (#1050)

* argument show, score_thr added to single_gpu_test

* Implemented label color visualization for show_result function

* Added show, score_thr argument for base 3 model(mmdetection3d)

* Fixed typo(color < 1) for show_result function

* Applied pre-commit run --all-files

* Revised documentation of show_result and revised variable name

* Updated documentation and set default value of score_thr to None
parent 5f1366ce
...@@ -44,7 +44,12 @@ def single_gpu_test(model, ...@@ -44,7 +44,12 @@ def single_gpu_test(model,
models_3d = (Base3DDetector, Base3DSegmentor, models_3d = (Base3DDetector, Base3DSegmentor,
SingleStageMono3DDetector) SingleStageMono3DDetector)
if isinstance(model.module, models_3d): if isinstance(model.module, models_3d):
model.module.show_results(data, result, out_dir=out_dir) model.module.show_results(
data,
result,
out_dir=out_dir,
show=show,
score_thr=show_score_thr)
# Visualize the results of MMDetection model # Visualize the results of MMDetection model
# 'show_result' is MMdetection visualization API # 'show_result' is MMdetection visualization API
else: else:
......
...@@ -78,7 +78,8 @@ def show_result(points, ...@@ -78,7 +78,8 @@ def show_result(points,
out_dir, out_dir,
filename, filename,
show=False, show=False,
snapshot=False): snapshot=False,
pred_labels=None):
"""Convert results into format that is directly readable for meshlab. """Convert results into format that is directly readable for meshlab.
Args: Args:
...@@ -87,8 +88,11 @@ def show_result(points, ...@@ -87,8 +88,11 @@ def show_result(points,
pred_bboxes (np.ndarray): Predicted boxes. pred_bboxes (np.ndarray): Predicted boxes.
out_dir (str): Path of output directory out_dir (str): Path of output directory
filename (str): Filename of the current frame. filename (str): Filename of the current frame.
show (bool): Visualize the results online. Defaults to False. show (bool, optional): Visualize the results online. Defaults to False.
snapshot (bool): Whether to save the online results. Defaults to False. snapshot (bool, optional): Whether to save the online results.
Defaults to False.
pred_labels (np.ndarray, optional): Predicted labels of boxes.
Defaults to None.
""" """
result_path = osp.join(out_dir, filename) result_path = osp.join(out_dir, filename)
mmcv.mkdir_or_exist(result_path) mmcv.mkdir_or_exist(result_path)
...@@ -98,7 +102,23 @@ def show_result(points, ...@@ -98,7 +102,23 @@ def show_result(points,
vis = Visualizer(points) vis = Visualizer(points)
if pred_bboxes is not None: if pred_bboxes is not None:
if pred_labels is None:
vis.add_bboxes(bbox3d=pred_bboxes) vis.add_bboxes(bbox3d=pred_bboxes)
else:
palette = np.random.randint(
0, 255, size=(pred_labels.max() + 1, 3)) / 256
labelDict = {}
for j in range(len(pred_labels)):
i = int(pred_labels[j].numpy())
if labelDict.get(i) is None:
labelDict[i] = []
labelDict[i].append(pred_bboxes[j])
for i in labelDict:
vis.add_bboxes(
bbox3d=np.array(labelDict[i]),
bbox_color=palette[i],
points_in_box_color=palette[i])
if gt_bboxes is not None: if gt_bboxes is not None:
vis.add_bboxes(bbox3d=gt_bboxes, bbox_color=(0, 0, 1)) vis.add_bboxes(bbox3d=gt_bboxes, bbox_color=(0, 0, 1))
show_path = osp.join(result_path, show_path = osp.join(result_path,
......
...@@ -60,13 +60,18 @@ class Base3DDetector(BaseDetector): ...@@ -60,13 +60,18 @@ class Base3DDetector(BaseDetector):
else: else:
return self.forward_test(**kwargs) return self.forward_test(**kwargs)
def show_results(self, data, result, out_dir): def show_results(self, data, result, out_dir, show=False, score_thr=None):
"""Results visualization. """Results visualization.
Args: Args:
data (list[dict]): Input points and the information of the sample. data (list[dict]): Input points and the information of the sample.
result (list[dict]): Prediction results. result (list[dict]): Prediction results.
out_dir (str): Output directory of visualization result. out_dir (str): Output directory of visualization result.
show (bool, optional): Determines whether you are
going to show result by open3d.
Defaults to False.
score_thr (float, optional): Score threshold of bounding boxes.
Default to None.
""" """
for batch_id in range(len(result)): for batch_id in range(len(result)):
if isinstance(data['points'][0], DC): if isinstance(data['points'][0], DC):
...@@ -93,6 +98,12 @@ class Base3DDetector(BaseDetector): ...@@ -93,6 +98,12 @@ class Base3DDetector(BaseDetector):
assert out_dir is not None, 'Expect out_dir, got none.' assert out_dir is not None, 'Expect out_dir, got none.'
pred_bboxes = result[batch_id]['boxes_3d'] pred_bboxes = result[batch_id]['boxes_3d']
pred_labels = result[batch_id]['labels_3d']
if score_thr is not None:
mask = result[batch_id]['scores_3d'] > score_thr
pred_bboxes = pred_bboxes[mask]
pred_labels = pred_labels[mask]
# for now we convert points and bbox into depth mode # for now we convert points and bbox into depth mode
if (box_mode_3d == Box3DMode.CAM) or (box_mode_3d if (box_mode_3d == Box3DMode.CAM) or (box_mode_3d
...@@ -105,4 +116,11 @@ class Base3DDetector(BaseDetector): ...@@ -105,4 +116,11 @@ class Base3DDetector(BaseDetector):
ValueError( ValueError(
f'Unsupported box_mode_3d {box_mode_3d} for convertion!') f'Unsupported box_mode_3d {box_mode_3d} for convertion!')
pred_bboxes = pred_bboxes.tensor.cpu().numpy() pred_bboxes = pred_bboxes.tensor.cpu().numpy()
show_result(points, None, pred_bboxes, out_dir, file_name) show_result(
points,
None,
pred_bboxes,
out_dir,
file_name,
show=show,
pred_labels=pred_labels)
...@@ -178,13 +178,20 @@ class SingleStageMono3DDetector(SingleStageDetector): ...@@ -178,13 +178,20 @@ class SingleStageMono3DDetector(SingleStageDetector):
return [bbox_list] return [bbox_list]
def show_results(self, data, result, out_dir): def show_results(self, data, result, out_dir, show=False, score_thr=None):
"""Results visualization. """Results visualization.
Args: Args:
data (list[dict]): Input images and the information of the sample. data (list[dict]): Input images and the information of the sample.
result (list[dict]): Prediction results. result (list[dict]): Prediction results.
out_dir (str): Output directory of visualization result. out_dir (str): Output directory of visualization result.
show (bool, optional): Determines whether you are
going to show result by open3d.
Defaults to False.
TODO: implement score_thr of single_stage_mono3d.
score_thr (float, optional): Score threshold of bounding boxes.
Default to None.
Not implemented yet, but it is here for unification.
""" """
for batch_id in range(len(result)): for batch_id in range(len(result)):
if isinstance(data['img_metas'][0], DC): if isinstance(data['img_metas'][0], DC):
...@@ -215,4 +222,4 @@ class SingleStageMono3DDetector(SingleStageDetector): ...@@ -215,4 +222,4 @@ class SingleStageMono3DDetector(SingleStageDetector):
out_dir, out_dir,
file_name, file_name,
'camera', 'camera',
show=True) show=show)
...@@ -72,7 +72,9 @@ class Base3DSegmentor(BaseSegmentor): ...@@ -72,7 +72,9 @@ class Base3DSegmentor(BaseSegmentor):
result, result,
palette=None, palette=None,
out_dir=None, out_dir=None,
ignore_index=None): ignore_index=None,
show=False,
score_thr=None):
"""Results visualization. """Results visualization.
Args: Args:
...@@ -85,6 +87,13 @@ class Base3DSegmentor(BaseSegmentor): ...@@ -85,6 +87,13 @@ class Base3DSegmentor(BaseSegmentor):
ignore_index (int, optional): The label index to be ignored, e.g. ignore_index (int, optional): The label index to be ignored, e.g.
unannotated points. If None is given, set to len(self.CLASSES). unannotated points. If None is given, set to len(self.CLASSES).
Defaults to None. Defaults to None.
show (bool, optional): Determines whether you are
going to show result by open3d.
Defaults to False.
TODO: implement score_thr of Base3DSegmentor.
score_thr (float, optional): Score threshold of bounding boxes.
Default to None.
Not implemented yet, but it is here for unification.
""" """
assert out_dir is not None, 'Expect out_dir, got none.' assert out_dir is not None, 'Expect out_dir, got none.'
if palette is None: if palette is None:
...@@ -123,4 +132,4 @@ class Base3DSegmentor(BaseSegmentor): ...@@ -123,4 +132,4 @@ class Base3DSegmentor(BaseSegmentor):
file_name, file_name,
palette, palette,
ignore_index, ignore_index,
show=True) show=show)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment