Commit 2e856c71 authored by Kai Chen's avatar Kai Chen
Browse files

allow mask visualization

parent 03e14ed2
......@@ -4,6 +4,7 @@ from abc import ABCMeta, abstractmethod
import mmcv
import numpy as np
import torch.nn as nn
import pycocotools.mask as maskUtils
from mmdet.core import tensor2imgs, get_classes
......@@ -86,6 +87,11 @@ class BaseDetector(nn.Module):
img_norm_cfg,
dataset='coco',
score_thr=0.3):
if isinstance(result, tuple):
bbox_result, segm_result = result
else:
bbox_result, segm_result = result, None
img_tensor = data['img'][0]
img_metas = data['img_meta'][0].data[0]
imgs = tensor2imgs(img_tensor, **img_norm_cfg)
......@@ -102,12 +108,23 @@ class BaseDetector(nn.Module):
for img, img_meta in zip(imgs, img_metas):
h, w, _ = img_meta['img_shape']
img_show = img[:h, :w, :]
bboxes = np.vstack(bbox_result)
# draw segmentation masks
if segm_result is not None:
segms = mmcv.concat_list(segm_result)
inds = np.where(bboxes[:, -1] > score_thr)[0]
for i in inds:
color_mask = np.random.randint(
0, 256, (1, 3), dtype=np.uint8)
mask = maskUtils.decode(segms[i]).astype(np.bool)
img_show[mask] = img_show[mask] * 0.5 + color_mask * 0.5
# draw bounding boxes
labels = [
np.full(bbox.shape[0], i, dtype=np.int32)
for i, bbox in enumerate(result)
for i, bbox in enumerate(bbox_result)
]
labels = np.concatenate(labels)
bboxes = np.vstack(result)
mmcv.imshow_det_bboxes(
img_show,
bboxes,
......
......@@ -306,14 +306,13 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
raise NotImplementedError
def show_result(self, data, result, img_norm_cfg, **kwargs):
# TODO: show segmentation masks
if self.with_mask:
ms_bbox_result, ms_segm_result = result
if isinstance(ms_bbox_result, dict):
result = (ms_bbox_result['ensemble'],
ms_segm_result['ensemble'])
else:
ms_bbox_result = result
if isinstance(ms_bbox_result, dict):
bbox_result = ms_bbox_result['ensemble']
else:
bbox_result = ms_bbox_result
super(CascadeRCNN, self).show_result(data, bbox_result, img_norm_cfg,
if isinstance(result, dict):
result = result['ensemble']
super(CascadeRCNN, self).show_result(data, result, img_norm_cfg,
**kwargs)
......@@ -25,10 +25,3 @@ class MaskRCNN(TwoStageDetector):
train_cfg=train_cfg,
test_cfg=test_cfg,
pretrained=pretrained)
def show_result(self, data, result, img_norm_cfg, **kwargs):
# TODO: show segmentation masks
assert isinstance(result, tuple)
assert len(result) == 2 # (bbox_results, segm_results)
super(MaskRCNN, self).show_result(data, result[0], img_norm_cfg,
**kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment