Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
80854128
Unverified
Commit
80854128
authored
Sep 01, 2022
by
ChaimZhu
Committed by
GitHub
Sep 01, 2022
Browse files
[Feats]: update visualization hook (#1792)
parent
d0c6c5b2
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
263 additions
and
2 deletions
+263
-2
configs/_base_/default_runtime.py
configs/_base_/default_runtime.py
+1
-1
mmdet3d/engine/hooks/__init__.py
mmdet3d/engine/hooks/__init__.py
+4
-0
mmdet3d/engine/hooks/visualization_hook.py
mmdet3d/engine/hooks/visualization_hook.py
+155
-0
tests/test_engine/test_hooks/test_visualization_hook.py
tests/test_engine/test_hooks/test_visualization_hook.py
+70
-0
tools/test.py
tools/test.py
+33
-1
No files found.
configs/_base_/default_runtime.py
View file @
80854128
...
@@ -6,7 +6,7 @@ default_hooks = dict(
...
@@ -6,7 +6,7 @@ default_hooks = dict(
param_scheduler
=
dict
(
type
=
'ParamSchedulerHook'
),
param_scheduler
=
dict
(
type
=
'ParamSchedulerHook'
),
checkpoint
=
dict
(
type
=
'CheckpointHook'
,
interval
=
1
),
checkpoint
=
dict
(
type
=
'CheckpointHook'
,
interval
=
1
),
sampler_seed
=
dict
(
type
=
'DistSamplerSeedHook'
),
sampler_seed
=
dict
(
type
=
'DistSamplerSeedHook'
),
)
visualization
=
dict
(
type
=
'Det3DVisualizationHook'
)
)
env_cfg
=
dict
(
env_cfg
=
dict
(
cudnn_benchmark
=
False
,
cudnn_benchmark
=
False
,
...
...
mmdet3d/engine/hooks/__init__.py
0 → 100644
View file @
80854128
# Copyright (c) OpenMMLab. All rights reserved.
from
visualization_hook
import
Det3DVisualizationHook
__all__
=
[
'Det3DVisualizationHook'
]
mmdet3d/engine/hooks/visualization_hook.py
0 → 100644
View file @
80854128
# Copyright (c) OpenMMLab. All rights reserved.
import
os.path
as
osp
import
warnings
from
typing
import
Optional
,
Sequence
import
mmcv
from
mmengine.fileio
import
FileClient
from
mmengine.hooks
import
Hook
from
mmengine.runner
import
Runner
from
mmengine.utils
import
mkdir_or_exist
from
mmengine.visualization
import
Visualizer
from
mmdet3d.registry
import
HOOKS
from
mmdet3d.structures
import
Det3DDataSample
@
HOOKS
.
register_module
()
class
Det3DVisualizationHook
(
Hook
):
"""Detection Visualization Hook. Used to visualize validation and testing
process prediction results.
In the testing phase:
1. If ``show`` is True, it means that only the prediction results are
visualized without storing data, so ``vis_backends`` needs to
be excluded.
2. If ``test_out_dir`` is specified, it means that the prediction results
need to be saved to ``test_out_dir``. In order to avoid vis_backends
also storing data, so ``vis_backends`` needs to be excluded.
3. ``vis_backends`` takes effect if the user does not specify ``show``
and `test_out_dir``. You can set ``vis_backends`` to WandbVisBackend or
TensorboardVisBackend to store the prediction result in Wandb or
Tensorboard.
Args:
draw (bool): whether to draw prediction results. If it is False,
it means that no drawing will be done. Defaults to False.
interval (int): The interval of visualization. Defaults to 50.
score_thr (float): The threshold to visualize the bboxes
and masks. Defaults to 0.3.
show (bool): Whether to display the drawn image. Default to False.
wait_time (float): The interval of show (s). Defaults to 0.
test_out_dir (str, optional): directory where painted images
will be saved in testing process.
file_client_args (dict): Arguments to instantiate a FileClient.
See :class:`mmengine.fileio.FileClient` for details.
Defaults to ``dict(backend='disk')``.
"""
def
__init__
(
self
,
draw
:
bool
=
False
,
interval
:
int
=
50
,
score_thr
:
float
=
0.3
,
show
:
bool
=
False
,
wait_time
:
float
=
0.
,
test_out_dir
:
Optional
[
str
]
=
None
,
file_client_args
:
dict
=
dict
(
backend
=
'disk'
)):
self
.
_visualizer
:
Visualizer
=
Visualizer
.
get_current_instance
()
self
.
interval
=
interval
self
.
score_thr
=
score_thr
self
.
show
=
show
if
self
.
show
:
# No need to think about vis backends.
self
.
_visualizer
.
_vis_backends
=
{}
warnings
.
warn
(
'The show is True, it means that only '
'the prediction results are visualized '
'without storing data, so vis_backends '
'needs to be excluded.'
)
self
.
wait_time
=
wait_time
self
.
file_client_args
=
file_client_args
.
copy
()
self
.
file_client
=
None
self
.
draw
=
draw
self
.
test_out_dir
=
test_out_dir
self
.
_test_index
=
0
def
after_val_iter
(
self
,
runner
:
Runner
,
batch_idx
:
int
,
data_batch
:
dict
,
outputs
:
Sequence
[
Det3DDataSample
])
->
None
:
"""Run after every ``self.interval`` validation iterations.
Args:
runner (:obj:`Runner`): The runner of the validation process.
batch_idx (int): The index of the current batch in the val loop.
data_batch (dict): Data from dataloader.
outputs (Sequence[:obj:`DetDataSample`]]): A batch of data samples
that contain annotations and predictions.
"""
if
self
.
draw
is
False
:
return
if
self
.
file_client
is
None
:
self
.
file_client
=
FileClient
(
**
self
.
file_client_args
)
# There is no guarantee that the same batch of images
# is visualized for each evaluation.
total_curr_iter
=
runner
.
iter
+
batch_idx
# Visualize only the first data
img_path
=
outputs
[
0
].
img_path
img_bytes
=
self
.
file_client
.
get
(
img_path
)
img
=
mmcv
.
imfrombytes
(
img_bytes
,
channel_order
=
'rgb'
)
if
total_curr_iter
%
self
.
interval
==
0
:
self
.
_visualizer
.
add_datasample
(
osp
.
basename
(
img_path
)
if
self
.
show
else
'val_img'
,
img
,
data_sample
=
outputs
[
0
],
show
=
self
.
show
,
wait_time
=
self
.
wait_time
,
pred_score_thr
=
self
.
score_thr
,
step
=
total_curr_iter
)
def
after_test_iter
(
self
,
runner
:
Runner
,
batch_idx
:
int
,
data_batch
:
dict
,
outputs
:
Sequence
[
Det3DDataSample
])
->
None
:
"""Run after every testing iterations.
Args:
runner (:obj:`Runner`): The runner of the testing process.
batch_idx (int): The index of the current batch in the val loop.
data_batch (dict): Data from dataloader.
outputs (Sequence[:obj:`DetDataSample`]): A batch of data samples
that contain annotations and predictions.
"""
if
self
.
draw
is
False
:
return
if
self
.
test_out_dir
is
not
None
:
self
.
test_out_dir
=
osp
.
join
(
runner
.
work_dir
,
runner
.
timestamp
,
self
.
test_out_dir
)
mkdir_or_exist
(
self
.
test_out_dir
)
if
self
.
file_client
is
None
:
self
.
file_client
=
FileClient
(
**
self
.
file_client_args
)
for
data_sample
in
outputs
:
self
.
_test_index
+=
1
img_path
=
data_sample
.
img_path
img_bytes
=
self
.
file_client
.
get
(
img_path
)
img
=
mmcv
.
imfrombytes
(
img_bytes
,
channel_order
=
'rgb'
)
out_file
=
None
if
self
.
test_out_dir
is
not
None
:
out_file
=
osp
.
basename
(
img_path
)
out_file
=
osp
.
join
(
self
.
test_out_dir
,
out_file
)
self
.
_visualizer
.
add_datasample
(
osp
.
basename
(
img_path
)
if
self
.
show
else
'test_img'
,
img
,
data_sample
=
data_sample
,
show
=
self
.
show
,
wait_time
=
self
.
wait_time
,
pred_score_thr
=
self
.
score_thr
,
out_file
=
out_file
,
step
=
self
.
_test_index
)
tests/test_engine/test_hooks/test_visualization_hook.py
0 → 100644
View file @
80854128
# Copyright (c) OpenMMLab. All rights reserved.
import
os.path
as
osp
import
shutil
import
time
from
unittest
import
TestCase
from
unittest.mock
import
Mock
import
torch
from
mmengine.structures
import
InstanceData
from
mmdet3d.engine.hooks
import
Det3DVisualizationHook
from
mmdet3d.structures
import
Det3DDataSample
from
mmdet3d.visualization
import
Det3DLocalVisualizer
def
_rand_bboxes
(
num_boxes
,
h
,
w
):
cx
,
cy
,
bw
,
bh
=
torch
.
rand
(
num_boxes
,
4
).
T
tl_x
=
((
cx
*
w
)
-
(
w
*
bw
/
2
)).
clip
(
0
,
w
)
tl_y
=
((
cy
*
h
)
-
(
h
*
bh
/
2
)).
clip
(
0
,
h
)
br_x
=
((
cx
*
w
)
+
(
w
*
bw
/
2
)).
clip
(
0
,
w
)
br_y
=
((
cy
*
h
)
+
(
h
*
bh
/
2
)).
clip
(
0
,
h
)
bboxes
=
torch
.
vstack
([
tl_x
,
tl_y
,
br_x
,
br_y
]).
T
return
bboxes
class
TestVisualizationHook
(
TestCase
):
def
setUp
(
self
)
->
None
:
Det3DLocalVisualizer
.
get_instance
(
'visualizer'
)
pred_instances
=
InstanceData
()
pred_instances
.
bboxes
=
_rand_bboxes
(
5
,
10
,
12
)
pred_instances
.
labels
=
torch
.
randint
(
0
,
2
,
(
5
,
))
pred_instances
.
scores
=
torch
.
rand
((
5
,
))
pred_det_data_sample
=
Det3DDataSample
()
pred_det_data_sample
.
set_metainfo
({
'img_path'
:
osp
.
join
(
osp
.
dirname
(
__file__
),
'../../data/color.jpg'
)
})
pred_det_data_sample
.
pred_instances
=
pred_instances
self
.
outputs
=
[
pred_det_data_sample
]
*
2
def
test_after_val_iter
(
self
):
runner
=
Mock
()
runner
.
iter
=
1
hook
=
Det3DVisualizationHook
()
hook
.
after_val_iter
(
runner
,
1
,
{},
self
.
outputs
)
def
test_after_test_iter
(
self
):
runner
=
Mock
()
runner
.
iter
=
1
hook
=
Det3DVisualizationHook
(
draw
=
True
)
hook
.
after_test_iter
(
runner
,
1
,
{},
self
.
outputs
)
self
.
assertEqual
(
hook
.
_test_index
,
2
)
# test
timestamp
=
time
.
strftime
(
'%Y%m%d_%H%M%S'
,
time
.
localtime
())
test_out_dir
=
timestamp
+
'1'
runner
.
work_dir
=
timestamp
runner
.
timestamp
=
'1'
hook
=
Det3DVisualizationHook
(
draw
=
False
,
test_out_dir
=
test_out_dir
)
hook
.
after_test_iter
(
runner
,
1
,
{},
self
.
outputs
)
self
.
assertTrue
(
not
osp
.
exists
(
f
'
{
timestamp
}
/1/
{
test_out_dir
}
'
))
hook
=
Det3DVisualizationHook
(
draw
=
True
,
test_out_dir
=
test_out_dir
)
hook
.
after_test_iter
(
runner
,
1
,
{},
self
.
outputs
)
self
.
assertTrue
(
osp
.
exists
(
f
'
{
timestamp
}
/1/
{
test_out_dir
}
'
))
shutil
.
rmtree
(
f
'
{
timestamp
}
'
)
tools/test.py
View file @
80854128
...
@@ -10,7 +10,7 @@ from mmengine.runner import Runner
...
@@ -10,7 +10,7 @@ from mmengine.runner import Runner
from
mmdet3d.utils
import
register_all_modules
,
replace_ceph_backend
from
mmdet3d.utils
import
register_all_modules
,
replace_ceph_backend
# TODO: support fuse_conv_bn
, visualization,
and format_only
# TODO: support fuse_conv_bn and format_only
def
parse_args
():
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
parser
=
argparse
.
ArgumentParser
(
description
=
'MMDet3D test (and eval) a model'
)
description
=
'MMDet3D test (and eval) a model'
)
...
@@ -19,6 +19,15 @@ def parse_args():
...
@@ -19,6 +19,15 @@ def parse_args():
parser
.
add_argument
(
parser
.
add_argument
(
'--work-dir'
,
'--work-dir'
,
help
=
'the directory to save the file containing evaluation metrics'
)
help
=
'the directory to save the file containing evaluation metrics'
)
parser
.
add_argument
(
'--show'
,
action
=
'store_true'
,
help
=
'show prediction results'
)
parser
.
add_argument
(
'--show-dir'
,
help
=
'directory where painted images will be saved. '
'If specified, it will be automatically saved '
'to the work_dir/timestamp/show_dir'
)
parser
.
add_argument
(
'--wait-time'
,
type
=
float
,
default
=
2
,
help
=
'the interval of show (s)'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--cfg-options'
,
'--cfg-options'
,
nargs
=
'+'
,
nargs
=
'+'
,
...
@@ -41,6 +50,26 @@ def parse_args():
...
@@ -41,6 +50,26 @@ def parse_args():
return
args
return
args
def
trigger_visualization_hook
(
cfg
,
args
):
default_hooks
=
cfg
.
default_hooks
if
'visualization'
in
default_hooks
:
visualization_hook
=
default_hooks
[
'visualization'
]
# Turn on visualization
visualization_hook
[
'draw'
]
=
True
if
args
.
show
:
visualization_hook
[
'show'
]
=
True
visualization_hook
[
'wait_time'
]
=
args
.
wait_time
if
args
.
show_dir
:
visualization_hook
[
'test_out_dir'
]
=
args
.
show_dir
else
:
raise
RuntimeError
(
'VisualizationHook must be included in default_hooks.'
'refer to usage '
'"visualization=dict(type=
\'
VisualizationHook
\'
)"'
)
return
cfg
def
main
():
def
main
():
args
=
parse_args
()
args
=
parse_args
()
...
@@ -68,6 +97,9 @@ def main():
...
@@ -68,6 +97,9 @@ def main():
cfg
.
load_from
=
args
.
checkpoint
cfg
.
load_from
=
args
.
checkpoint
if
args
.
show
or
args
.
show_dir
:
cfg
=
trigger_visualization_hook
(
cfg
,
args
)
# build the runner from config
# build the runner from config
if
'runner_type'
not
in
cfg
:
if
'runner_type'
not
in
cfg
:
# build the default runner
# build the default runner
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment