Unverified Commit ef6e0aa2 authored by Timothy's avatar Timothy Committed by GitHub
Browse files

[Fix] Fix missing parameter in `Det3DVisualizationHook` (#2118)



* Update visualization_hook.py

* Update test.py

* Update visualization_hook.py

* fix lint issues
Co-authored-by: default avatarshanmo <shanmo1412@gmail.com>
parent 68ef1d79
...@@ -40,6 +40,7 @@ class Det3DVisualizationHook(Hook): ...@@ -40,6 +40,7 @@ class Det3DVisualizationHook(Hook):
score_thr (float): The threshold to visualize the bboxes score_thr (float): The threshold to visualize the bboxes
and masks. Defaults to 0.3. and masks. Defaults to 0.3.
show (bool): Whether to display the drawn image. Default to False. show (bool): Whether to display the drawn image. Default to False.
vis_task (str): Visualization task. Defaults to 'mono_det'.
wait_time (float): The interval of show (s). Defaults to 0. wait_time (float): The interval of show (s). Defaults to 0.
test_out_dir (str, optional): directory where painted images test_out_dir (str, optional): directory where painted images
will be saved in testing process. will be saved in testing process.
...@@ -53,6 +54,7 @@ class Det3DVisualizationHook(Hook): ...@@ -53,6 +54,7 @@ class Det3DVisualizationHook(Hook):
interval: int = 50, interval: int = 50,
score_thr: float = 0.3, score_thr: float = 0.3,
show: bool = False, show: bool = False,
vis_task: str = 'mono_det',
wait_time: float = 0., wait_time: float = 0.,
test_out_dir: Optional[str] = None, test_out_dir: Optional[str] = None,
file_client_args: dict = dict(backend='disk')): file_client_args: dict = dict(backend='disk')):
...@@ -67,6 +69,7 @@ class Det3DVisualizationHook(Hook): ...@@ -67,6 +69,7 @@ class Det3DVisualizationHook(Hook):
'the prediction results are visualized ' 'the prediction results are visualized '
'without storing data, so vis_backends ' 'without storing data, so vis_backends '
'needs to be excluded.') 'needs to be excluded.')
self.vis_task = vis_task
self.wait_time = wait_time self.wait_time = wait_time
self.file_client_args = file_client_args.copy() self.file_client_args = file_client_args.copy()
...@@ -119,6 +122,7 @@ class Det3DVisualizationHook(Hook): ...@@ -119,6 +122,7 @@ class Det3DVisualizationHook(Hook):
data_input, data_input,
data_sample=outputs[0], data_sample=outputs[0],
show=self.show, show=self.show,
vis_task=self.vis_task,
wait_time=self.wait_time, wait_time=self.wait_time,
pred_score_thr=self.score_thr, pred_score_thr=self.score_thr,
step=total_curr_iter) step=total_curr_iter)
...@@ -173,6 +177,7 @@ class Det3DVisualizationHook(Hook): ...@@ -173,6 +177,7 @@ class Det3DVisualizationHook(Hook):
data_input, data_input,
data_sample=data_sample, data_sample=data_sample,
show=self.show, show=self.show,
vis_task=self.vis_task,
wait_time=self.wait_time, wait_time=self.wait_time,
pred_score_thr=self.score_thr, pred_score_thr=self.score_thr,
out_file=out_file, out_file=out_file,
......
...@@ -28,6 +28,14 @@ def parse_args(): ...@@ -28,6 +28,14 @@ def parse_args():
help='directory where painted images will be saved. ' help='directory where painted images will be saved. '
'If specified, it will be automatically saved ' 'If specified, it will be automatically saved '
'to the work_dir/timestamp/show_dir') 'to the work_dir/timestamp/show_dir')
parser.add_argument(
'--task',
type=str,
choices=[
'mono_det', 'multi-view_det', 'lidar_det', 'lidar_seg',
'multi-modality_det'
],
help='Determine the visualization method depending on the task.')
parser.add_argument( parser.add_argument(
'--wait-time', type=float, default=2, help='the interval of show (s)') '--wait-time', type=float, default=2, help='the interval of show (s)')
parser.add_argument( parser.add_argument(
...@@ -63,6 +71,7 @@ def trigger_visualization_hook(cfg, args): ...@@ -63,6 +71,7 @@ def trigger_visualization_hook(cfg, args):
visualization_hook['wait_time'] = args.wait_time visualization_hook['wait_time'] = args.wait_time
if args.show_dir: if args.show_dir:
visualization_hook['test_out_dir'] = args.show_dir visualization_hook['test_out_dir'] = args.show_dir
visualization_hook['vis_task'] = args.task
else: else:
raise RuntimeError( raise RuntimeError(
'VisualizationHook must be included in default_hooks.' 'VisualizationHook must be included in default_hooks.'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment