Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
28b3c6ea
Commit
28b3c6ea
authored
Mar 31, 2020
by
WXinlong
Browse files
add quick demo for inference
parent
22d25bed
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
102 additions
and
18 deletions
+102
-18
demo/coco_test_12510.jpg
demo/coco_test_12510.jpg
+0
-0
demo/inference_demo.py
demo/inference_demo.py
+16
-0
mmdet/apis/__init__.py
mmdet/apis/__init__.py
+2
-2
mmdet/apis/inference.py
mmdet/apis/inference.py
+84
-0
tools/test_ins_vis.py
tools/test_ins_vis.py
+0
-16
No files found.
demo/coco_test_12510.jpg
deleted
100644 → 0
View file @
22d25bed
179 KB
demo/inference_demo.py
0 → 100644
View file @
28b3c6ea
from
mmdet.apis
import
init_detector
,
inference_detector
,
show_result_pyplot
,
show_result_ins
import
mmcv
config_file
=
'../configs/solo/decoupled_solo_r50_fpn_8gpu_3x.py'
# download the checkpoint from model zoo and put it in `checkpoints/`
checkpoint_file
=
'../checkpoints/DECOUPLED_SOLO_R50_3x.pth'
# build the model from a config file and a checkpoint file
model
=
init_detector
(
config_file
,
checkpoint_file
,
device
=
'cuda:0'
)
# test a single image
img
=
'demo.jpg'
result
=
inference_detector
(
model
,
img
)
show_result_ins
(
img
,
result
,
model
.
CLASSES
,
score_thr
=
0.25
,
out_file
=
"demo_out.jpg"
)
mmdet/apis/__init__.py
View file @
28b3c6ea
from
.inference
import
(
async_inference_detector
,
inference_detector
,
from
.inference
import
(
async_inference_detector
,
inference_detector
,
init_detector
,
show_result
,
show_result_pyplot
)
init_detector
,
show_result
,
show_result_pyplot
,
show_result_ins
)
from
.train
import
get_root_logger
,
set_random_seed
,
train_detector
from
.train
import
get_root_logger
,
set_random_seed
,
train_detector
__all__
=
[
__all__
=
[
'get_root_logger'
,
'set_random_seed'
,
'train_detector'
,
'init_detector'
,
'get_root_logger'
,
'set_random_seed'
,
'train_detector'
,
'init_detector'
,
'async_inference_detector'
,
'inference_detector'
,
'show_result'
,
'async_inference_detector'
,
'inference_detector'
,
'show_result'
,
'show_result_pyplot'
'show_result_pyplot'
,
'show_result_ins'
]
]
mmdet/apis/inference.py
View file @
28b3c6ea
...
@@ -12,6 +12,8 @@ from mmdet.core import get_classes
...
@@ -12,6 +12,8 @@ from mmdet.core import get_classes
from
mmdet.datasets.pipelines
import
Compose
from
mmdet.datasets.pipelines
import
Compose
from
mmdet.models
import
build_detector
from
mmdet.models
import
build_detector
import
cv2
from
scipy
import
ndimage
def
init_detector
(
config
,
checkpoint
=
None
,
device
=
'cuda:0'
):
def
init_detector
(
config
,
checkpoint
=
None
,
device
=
'cuda:0'
):
"""Initialize a detector from config file.
"""Initialize a detector from config file.
...
@@ -202,3 +204,85 @@ def show_result_pyplot(img,
...
@@ -202,3 +204,85 @@ def show_result_pyplot(img,
img
,
result
,
class_names
,
score_thr
=
score_thr
,
show
=
False
)
img
,
result
,
class_names
,
score_thr
=
score_thr
,
show
=
False
)
plt
.
figure
(
figsize
=
fig_size
)
plt
.
figure
(
figsize
=
fig_size
)
plt
.
imshow
(
mmcv
.
bgr2rgb
(
img
))
plt
.
imshow
(
mmcv
.
bgr2rgb
(
img
))
def
show_result_ins
(
img
,
result
,
class_names
,
score_thr
=
0.3
,
sort_by_density
=
False
,
out_file
=
None
):
"""Visualize the instance segmentation results on the image.
Args:
img (str or np.ndarray): Image filename or loaded image.
result (tuple[list] or list): The instance segmentation result.
class_names (list[str] or tuple[str]): A list of class names.
score_thr (float): The threshold to visualize the masks.
sort_by_density (bool): sort the masks by their density.
out_file (str, optional): If specified, the visualization result will
be written to the out file instead of shown in a window.
Returns:
np.ndarray or None: If neither `show` nor `out_file` is specified, the
visualized image is returned, otherwise None is returned.
"""
assert
isinstance
(
class_names
,
(
tuple
,
list
))
img
=
mmcv
.
imread
(
img
)
img_show
=
img
.
copy
()
h
,
w
,
_
=
img
.
shape
cur_result
=
result
[
0
]
seg_label
=
cur_result
[
0
]
seg_label
=
seg_label
.
cpu
().
numpy
().
astype
(
np
.
uint8
)
cate_label
=
cur_result
[
1
]
cate_label
=
cate_label
.
cpu
().
numpy
()
score
=
cur_result
[
2
].
cpu
().
numpy
()
vis_inds
=
score
>
score_thr
seg_label
=
seg_label
[
vis_inds
]
num_mask
=
seg_label
.
shape
[
0
]
cate_label
=
cate_label
[
vis_inds
]
cate_score
=
score
[
vis_inds
]
if
sort_by_density
:
mask_density
=
[]
for
idx
in
range
(
num_mask
):
cur_mask
=
seg_label
[
idx
,
:,
:]
cur_mask
=
mmcv
.
imresize
(
cur_mask
,
(
w
,
h
))
cur_mask
=
(
cur_mask
>
0.5
).
astype
(
np
.
int32
)
mask_density
.
append
(
cur_mask
.
sum
())
orders
=
np
.
argsort
(
mask_density
)
seg_label
=
seg_label
[
orders
]
cate_label
=
cate_label
[
orders
]
cate_score
=
cate_score
[
orders
]
np
.
random
.
seed
(
42
)
color_masks
=
[
np
.
random
.
randint
(
0
,
256
,
(
1
,
3
),
dtype
=
np
.
uint8
)
for
_
in
range
(
num_mask
)
]
for
idx
in
range
(
num_mask
):
idx
=
-
(
idx
+
1
)
cur_mask
=
seg_label
[
idx
,
:,
:]
cur_mask
=
mmcv
.
imresize
(
cur_mask
,
(
w
,
h
))
cur_mask
=
(
cur_mask
>
0.5
).
astype
(
np
.
uint8
)
if
cur_mask
.
sum
()
==
0
:
continue
color_mask
=
color_masks
[
idx
]
cur_mask_bool
=
cur_mask
.
astype
(
np
.
bool
)
img_show
[
cur_mask_bool
]
=
img
[
cur_mask_bool
]
*
0.5
+
color_mask
*
0.5
cur_cate
=
cate_label
[
idx
]
cur_score
=
cate_score
[
idx
]
label_text
=
class_names
[
cur_cate
]
#label_text += '|{:.02f}'.format(cur_score)
center_y
,
center_x
=
ndimage
.
measurements
.
center_of_mass
(
cur_mask
)
vis_pos
=
(
max
(
int
(
center_x
)
-
10
,
0
),
int
(
center_y
))
cv2
.
putText
(
img_show
,
label_text
,
vis_pos
,
cv2
.
FONT_HERSHEY_COMPLEX
,
0.3
,
(
255
,
255
,
255
))
# green
if
out_file
is
None
:
return
img
else
:
mmcv
.
imwrite
(
img_show
,
out_file
)
tools/test_ins_vis.py
View file @
28b3c6ea
...
@@ -33,10 +33,8 @@ def vis_seg(data, result, img_norm_cfg, data_id, colors, score_thr, save_dir):
...
@@ -33,10 +33,8 @@ def vis_seg(data, result, img_norm_cfg, data_id, colors, score_thr, save_dir):
seg_label
=
cur_result
[
0
]
seg_label
=
cur_result
[
0
]
seg_label
=
seg_label
.
cpu
().
numpy
().
astype
(
np
.
uint8
)
seg_label
=
seg_label
.
cpu
().
numpy
().
astype
(
np
.
uint8
)
cate_label
=
cur_result
[
1
]
cate_label
=
cur_result
[
1
]
cate_label
=
cate_label
.
cpu
().
numpy
()
cate_label
=
cate_label
.
cpu
().
numpy
()
score
=
cur_result
[
2
].
cpu
().
numpy
()
score
=
cur_result
[
2
].
cpu
().
numpy
()
vis_inds
=
score
>
score_thr
vis_inds
=
score
>
score_thr
...
@@ -51,7 +49,6 @@ def vis_seg(data, result, img_norm_cfg, data_id, colors, score_thr, save_dir):
...
@@ -51,7 +49,6 @@ def vis_seg(data, result, img_norm_cfg, data_id, colors, score_thr, save_dir):
cur_mask
=
mmcv
.
imresize
(
cur_mask
,
(
w
,
h
))
cur_mask
=
mmcv
.
imresize
(
cur_mask
,
(
w
,
h
))
cur_mask
=
(
cur_mask
>
0.5
).
astype
(
np
.
int32
)
cur_mask
=
(
cur_mask
>
0.5
).
astype
(
np
.
int32
)
mask_density
.
append
(
cur_mask
.
sum
())
mask_density
.
append
(
cur_mask
.
sum
())
orders
=
np
.
argsort
(
mask_density
)
orders
=
np
.
argsort
(
mask_density
)
seg_label
=
seg_label
[
orders
]
seg_label
=
seg_label
[
orders
]
cate_label
=
cate_label
[
orders
]
cate_label
=
cate_label
[
orders
]
...
@@ -63,25 +60,13 @@ def vis_seg(data, result, img_norm_cfg, data_id, colors, score_thr, save_dir):
...
@@ -63,25 +60,13 @@ def vis_seg(data, result, img_norm_cfg, data_id, colors, score_thr, save_dir):
cur_mask
=
seg_label
[
idx
,
:,:]
cur_mask
=
seg_label
[
idx
,
:,:]
cur_mask
=
mmcv
.
imresize
(
cur_mask
,
(
w
,
h
))
cur_mask
=
mmcv
.
imresize
(
cur_mask
,
(
w
,
h
))
cur_mask
=
(
cur_mask
>
0.5
).
astype
(
np
.
uint8
)
cur_mask
=
(
cur_mask
>
0.5
).
astype
(
np
.
uint8
)
if
cur_mask
.
sum
()
==
0
:
if
cur_mask
.
sum
()
==
0
:
continue
continue
color_mask
=
np
.
random
.
randint
(
color_mask
=
np
.
random
.
randint
(
0
,
256
,
(
1
,
3
),
dtype
=
np
.
uint8
)
0
,
256
,
(
1
,
3
),
dtype
=
np
.
uint8
)
cur_mask_bool
=
cur_mask
.
astype
(
np
.
bool
)
cur_mask_bool
=
cur_mask
.
astype
(
np
.
bool
)
seg_show
[
cur_mask_bool
]
=
img_show
[
cur_mask_bool
]
*
0.5
+
color_mask
*
0.5
seg_show
[
cur_mask_bool
]
=
img_show
[
cur_mask_bool
]
*
0.5
+
color_mask
*
0.5
for
idx
in
range
(
num_mask
):
idx
=
-
(
idx
+
1
)
cur_mask
=
seg_label
[
idx
,
:,
:]
cur_mask
=
mmcv
.
imresize
(
cur_mask
,
(
w
,
h
))
cur_mask
=
(
cur_mask
>
0.5
).
astype
(
np
.
uint8
)
if
cur_mask
.
sum
()
==
0
:
continue
cur_cate
=
cate_label
[
idx
]
cur_cate
=
cate_label
[
idx
]
cur_score
=
cate_score
[
idx
]
cur_score
=
cate_score
[
idx
]
...
@@ -92,7 +77,6 @@ def vis_seg(data, result, img_norm_cfg, data_id, colors, score_thr, save_dir):
...
@@ -92,7 +77,6 @@ def vis_seg(data, result, img_norm_cfg, data_id, colors, score_thr, save_dir):
vis_pos
=
(
max
(
int
(
center_x
)
-
10
,
0
),
int
(
center_y
))
vis_pos
=
(
max
(
int
(
center_x
)
-
10
,
0
),
int
(
center_y
))
cv2
.
putText
(
seg_show
,
label_text
,
vis_pos
,
cv2
.
putText
(
seg_show
,
label_text
,
vis_pos
,
cv2
.
FONT_HERSHEY_COMPLEX
,
0.3
,
(
255
,
255
,
255
))
# green
cv2
.
FONT_HERSHEY_COMPLEX
,
0.3
,
(
255
,
255
,
255
))
# green
mmcv
.
imwrite
(
seg_show
,
'{}/{}.jpg'
.
format
(
save_dir
,
data_id
))
mmcv
.
imwrite
(
seg_show
,
'{}/{}.jpg'
.
format
(
save_dir
,
data_id
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment