Commit edb03937 authored by Bo Li's avatar Bo Li Committed by Kai Chen
Browse files

Added mask visualization part to inference part and add out_file interface. (#403)

* Update README.md

* Update inference.py

* Update README.md

* Update inference.py

Added mask visualization part for inferring.

* Update README.md

* Update inference.py

* Update inference.py

convert all tabs to spaces

* Update inference.py
parent a3c8ddf7
import mmcv import mmcv
import numpy as np import numpy as np
import pycocotools.mask as maskUtils
import torch import torch
from mmdet.core import get_classes
from mmdet.datasets import to_tensor from mmdet.datasets import to_tensor
from mmdet.datasets.transforms import ImageTransform from mmdet.datasets.transforms import ImageTransform
from mmdet.core import get_classes
def _prepare_data(img, img_transform, cfg, device): def _prepare_data(img, img_transform, cfg, device):
...@@ -50,18 +51,33 @@ def inference_detector(model, imgs, cfg, device='cuda:0'): ...@@ -50,18 +51,33 @@ def inference_detector(model, imgs, cfg, device='cuda:0'):
return _inference_generator(model, imgs, img_transform, cfg, device) return _inference_generator(model, imgs, img_transform, cfg, device)
def show_result(img, result, dataset='coco', score_thr=0.3): def show_result(img, result, dataset='coco', score_thr=0.3, out_file=None):
img = mmcv.imread(img)
class_names = get_classes(dataset) class_names = get_classes(dataset)
if isinstance(result, tuple):
bbox_result, segm_result = result
else:
bbox_result, segm_result = result, None
bboxes = np.vstack(bbox_result)
# draw segmentation masks
if segm_result is not None:
segms = mmcv.concat_list(segm_result)
inds = np.where(bboxes[:, -1] > score_thr)[0]
for i in inds:
color_mask = np.random.randint(
0, 256, (1, 3), dtype=np.uint8)
mask = maskUtils.decode(segms[i]).astype(np.bool)
img[mask] = img[mask] * 0.5 + color_mask * 0.5
# draw bounding boxes
labels = [ labels = [
np.full(bbox.shape[0], i, dtype=np.int32) np.full(bbox.shape[0], i, dtype=np.int32)
for i, bbox in enumerate(result) for i, bbox in enumerate(bbox_result)
] ]
labels = np.concatenate(labels) labels = np.concatenate(labels)
bboxes = np.vstack(result)
img = mmcv.imread(img)
mmcv.imshow_det_bboxes( mmcv.imshow_det_bboxes(
img.copy(), img.copy(),
bboxes, bboxes,
labels, labels,
class_names=class_names, class_names=class_names,
score_thr=score_thr) score_thr=score_thr,
show=out_file is None)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment