Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
65a2e5ea
Unverified
Commit
65a2e5ea
authored
Dec 05, 2018
by
Kai Chen
Committed by
GitHub
Dec 05, 2018
Browse files
Merge pull request #143 from hellock/mask-vis
Allow mask visualization
parents
a6ee0532
2e856c71
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
25 additions
and
16 deletions
+25
-16
mmdet/models/detectors/base.py
mmdet/models/detectors/base.py
+19
-2
mmdet/models/detectors/cascade_rcnn.py
mmdet/models/detectors/cascade_rcnn.py
+6
-7
mmdet/models/detectors/mask_rcnn.py
mmdet/models/detectors/mask_rcnn.py
+0
-7
No files found.
mmdet/models/detectors/base.py
View file @
65a2e5ea
...
...
@@ -4,6 +4,7 @@ from abc import ABCMeta, abstractmethod
import
mmcv
import
numpy
as
np
import
torch.nn
as
nn
import
pycocotools.mask
as
maskUtils
from
mmdet.core
import
tensor2imgs
,
get_classes
...
...
@@ -86,6 +87,11 @@ class BaseDetector(nn.Module):
img_norm_cfg
,
dataset
=
'coco'
,
score_thr
=
0.3
):
if
isinstance
(
result
,
tuple
):
bbox_result
,
segm_result
=
result
else
:
bbox_result
,
segm_result
=
result
,
None
img_tensor
=
data
[
'img'
][
0
]
img_metas
=
data
[
'img_meta'
][
0
].
data
[
0
]
imgs
=
tensor2imgs
(
img_tensor
,
**
img_norm_cfg
)
...
...
@@ -102,12 +108,23 @@ class BaseDetector(nn.Module):
for
img
,
img_meta
in
zip
(
imgs
,
img_metas
):
h
,
w
,
_
=
img_meta
[
'img_shape'
]
img_show
=
img
[:
h
,
:
w
,
:]
bboxes
=
np
.
vstack
(
bbox_result
)
# draw segmentation masks
if
segm_result
is
not
None
:
segms
=
mmcv
.
concat_list
(
segm_result
)
inds
=
np
.
where
(
bboxes
[:,
-
1
]
>
score_thr
)[
0
]
for
i
in
inds
:
color_mask
=
np
.
random
.
randint
(
0
,
256
,
(
1
,
3
),
dtype
=
np
.
uint8
)
mask
=
maskUtils
.
decode
(
segms
[
i
]).
astype
(
np
.
bool
)
img_show
[
mask
]
=
img_show
[
mask
]
*
0.5
+
color_mask
*
0.5
# draw bounding boxes
labels
=
[
np
.
full
(
bbox
.
shape
[
0
],
i
,
dtype
=
np
.
int32
)
for
i
,
bbox
in
enumerate
(
result
)
for
i
,
bbox
in
enumerate
(
bbox_
result
)
]
labels
=
np
.
concatenate
(
labels
)
bboxes
=
np
.
vstack
(
result
)
mmcv
.
imshow_det_bboxes
(
img_show
,
bboxes
,
...
...
mmdet/models/detectors/cascade_rcnn.py
View file @
65a2e5ea
...
...
@@ -306,14 +306,13 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
raise
NotImplementedError
def
show_result
(
self
,
data
,
result
,
img_norm_cfg
,
**
kwargs
):
# TODO: show segmentation masks
if
self
.
with_mask
:
ms_bbox_result
,
ms_segm_result
=
result
if
isinstance
(
ms_bbox_result
,
dict
):
result
=
(
ms_bbox_result
[
'ensemble'
],
ms_segm_result
[
'ensemble'
])
else
:
ms_bbox_result
=
result
if
isinstance
(
ms_bbox_result
,
dict
):
bbox_result
=
ms_bbox_result
[
'ensemble'
]
else
:
bbox_result
=
ms_bbox_result
super
(
CascadeRCNN
,
self
).
show_result
(
data
,
bbox_result
,
img_norm_cfg
,
if
isinstance
(
result
,
dict
):
result
=
result
[
'ensemble'
]
super
(
CascadeRCNN
,
self
).
show_result
(
data
,
result
,
img_norm_cfg
,
**
kwargs
)
mmdet/models/detectors/mask_rcnn.py
View file @
65a2e5ea
...
...
@@ -25,10 +25,3 @@ class MaskRCNN(TwoStageDetector):
train_cfg
=
train_cfg
,
test_cfg
=
test_cfg
,
pretrained
=
pretrained
)
def
show_result
(
self
,
data
,
result
,
img_norm_cfg
,
**
kwargs
):
# TODO: show segmentation masks
assert
isinstance
(
result
,
tuple
)
assert
len
(
result
)
==
2
# (bbox_results, segm_results)
super
(
MaskRCNN
,
self
).
show_result
(
data
,
result
[
0
],
img_norm_cfg
,
**
kwargs
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment