Commit 28b3c6ea authored by WXinlong's avatar WXinlong
Browse files

add quick demo for inference

parent 22d25bed
from mmdet.apis import init_detector, inference_detector, show_result_pyplot, show_result_ins
import mmcv
config_file = '../configs/solo/decoupled_solo_r50_fpn_8gpu_3x.py'
# download the checkpoint from model zoo and put it in `checkpoints/`
checkpoint_file = '../checkpoints/DECOUPLED_SOLO_R50_3x.pth'
# build the model from a config file and a checkpoint file
model = init_detector(config_file, checkpoint_file, device='cuda:0')
# test a single image
img = 'demo.jpg'
result = inference_detector(model, img)
show_result_ins(img, result, model.CLASSES, score_thr=0.25, out_file="demo_out.jpg")
from .inference import (async_inference_detector, inference_detector, from .inference import (async_inference_detector, inference_detector,
init_detector, show_result, show_result_pyplot) init_detector, show_result, show_result_pyplot, show_result_ins)
from .train import get_root_logger, set_random_seed, train_detector from .train import get_root_logger, set_random_seed, train_detector
__all__ = [ __all__ = [
'get_root_logger', 'set_random_seed', 'train_detector', 'init_detector', 'get_root_logger', 'set_random_seed', 'train_detector', 'init_detector',
'async_inference_detector', 'inference_detector', 'show_result', 'async_inference_detector', 'inference_detector', 'show_result',
'show_result_pyplot' 'show_result_pyplot', 'show_result_ins'
] ]
...@@ -12,6 +12,8 @@ from mmdet.core import get_classes ...@@ -12,6 +12,8 @@ from mmdet.core import get_classes
from mmdet.datasets.pipelines import Compose from mmdet.datasets.pipelines import Compose
from mmdet.models import build_detector from mmdet.models import build_detector
import cv2
from scipy import ndimage
def init_detector(config, checkpoint=None, device='cuda:0'): def init_detector(config, checkpoint=None, device='cuda:0'):
"""Initialize a detector from config file. """Initialize a detector from config file.
...@@ -202,3 +204,85 @@ def show_result_pyplot(img, ...@@ -202,3 +204,85 @@ def show_result_pyplot(img,
img, result, class_names, score_thr=score_thr, show=False) img, result, class_names, score_thr=score_thr, show=False)
plt.figure(figsize=fig_size) plt.figure(figsize=fig_size)
plt.imshow(mmcv.bgr2rgb(img)) plt.imshow(mmcv.bgr2rgb(img))
def show_result_ins(img,
result,
class_names,
score_thr=0.3,
sort_by_density=False,
out_file=None):
"""Visualize the instance segmentation results on the image.
Args:
img (str or np.ndarray): Image filename or loaded image.
result (tuple[list] or list): The instance segmentation result.
class_names (list[str] or tuple[str]): A list of class names.
score_thr (float): The threshold to visualize the masks.
sort_by_density (bool): sort the masks by their density.
out_file (str, optional): If specified, the visualization result will
be written to the out file instead of shown in a window.
Returns:
np.ndarray or None: If neither `show` nor `out_file` is specified, the
visualized image is returned, otherwise None is returned.
"""
assert isinstance(class_names, (tuple, list))
img = mmcv.imread(img)
img_show = img.copy()
h, w, _ = img.shape
cur_result = result[0]
seg_label = cur_result[0]
seg_label = seg_label.cpu().numpy().astype(np.uint8)
cate_label = cur_result[1]
cate_label = cate_label.cpu().numpy()
score = cur_result[2].cpu().numpy()
vis_inds = score > score_thr
seg_label = seg_label[vis_inds]
num_mask = seg_label.shape[0]
cate_label = cate_label[vis_inds]
cate_score = score[vis_inds]
if sort_by_density:
mask_density = []
for idx in range(num_mask):
cur_mask = seg_label[idx, :, :]
cur_mask = mmcv.imresize(cur_mask, (w, h))
cur_mask = (cur_mask > 0.5).astype(np.int32)
mask_density.append(cur_mask.sum())
orders = np.argsort(mask_density)
seg_label = seg_label[orders]
cate_label = cate_label[orders]
cate_score = cate_score[orders]
np.random.seed(42)
color_masks = [
np.random.randint(0, 256, (1, 3), dtype=np.uint8)
for _ in range(num_mask)
]
for idx in range(num_mask):
idx = -(idx+1)
cur_mask = seg_label[idx, :, :]
cur_mask = mmcv.imresize(cur_mask, (w, h))
cur_mask = (cur_mask > 0.5).astype(np.uint8)
if cur_mask.sum() == 0:
continue
color_mask = color_masks[idx]
cur_mask_bool = cur_mask.astype(np.bool)
img_show[cur_mask_bool] = img[cur_mask_bool] * 0.5 + color_mask * 0.5
cur_cate = cate_label[idx]
cur_score = cate_score[idx]
label_text = class_names[cur_cate]
#label_text += '|{:.02f}'.format(cur_score)
center_y, center_x = ndimage.measurements.center_of_mass(cur_mask)
vis_pos = (max(int(center_x) - 10, 0), int(center_y))
cv2.putText(img_show, label_text, vis_pos,
cv2.FONT_HERSHEY_COMPLEX, 0.3, (255, 255, 255)) # green
if out_file is None:
return img
else:
mmcv.imwrite(img_show, out_file)
...@@ -33,10 +33,8 @@ def vis_seg(data, result, img_norm_cfg, data_id, colors, score_thr, save_dir): ...@@ -33,10 +33,8 @@ def vis_seg(data, result, img_norm_cfg, data_id, colors, score_thr, save_dir):
seg_label = cur_result[0] seg_label = cur_result[0]
seg_label = seg_label.cpu().numpy().astype(np.uint8) seg_label = seg_label.cpu().numpy().astype(np.uint8)
cate_label = cur_result[1] cate_label = cur_result[1]
cate_label = cate_label.cpu().numpy() cate_label = cate_label.cpu().numpy()
score = cur_result[2].cpu().numpy() score = cur_result[2].cpu().numpy()
vis_inds = score > score_thr vis_inds = score > score_thr
...@@ -51,7 +49,6 @@ def vis_seg(data, result, img_norm_cfg, data_id, colors, score_thr, save_dir): ...@@ -51,7 +49,6 @@ def vis_seg(data, result, img_norm_cfg, data_id, colors, score_thr, save_dir):
cur_mask = mmcv.imresize(cur_mask, (w, h)) cur_mask = mmcv.imresize(cur_mask, (w, h))
cur_mask = (cur_mask > 0.5).astype(np.int32) cur_mask = (cur_mask > 0.5).astype(np.int32)
mask_density.append(cur_mask.sum()) mask_density.append(cur_mask.sum())
orders = np.argsort(mask_density) orders = np.argsort(mask_density)
seg_label = seg_label[orders] seg_label = seg_label[orders]
cate_label = cate_label[orders] cate_label = cate_label[orders]
...@@ -63,25 +60,13 @@ def vis_seg(data, result, img_norm_cfg, data_id, colors, score_thr, save_dir): ...@@ -63,25 +60,13 @@ def vis_seg(data, result, img_norm_cfg, data_id, colors, score_thr, save_dir):
cur_mask = seg_label[idx, :,:] cur_mask = seg_label[idx, :,:]
cur_mask = mmcv.imresize(cur_mask, (w, h)) cur_mask = mmcv.imresize(cur_mask, (w, h))
cur_mask = (cur_mask > 0.5).astype(np.uint8) cur_mask = (cur_mask > 0.5).astype(np.uint8)
if cur_mask.sum() == 0: if cur_mask.sum() == 0:
continue continue
color_mask = np.random.randint( color_mask = np.random.randint(
0, 256, (1, 3), dtype=np.uint8) 0, 256, (1, 3), dtype=np.uint8)
cur_mask_bool = cur_mask.astype(np.bool) cur_mask_bool = cur_mask.astype(np.bool)
seg_show[cur_mask_bool] = img_show[cur_mask_bool] * 0.5 + color_mask * 0.5 seg_show[cur_mask_bool] = img_show[cur_mask_bool] * 0.5 + color_mask * 0.5
for idx in range(num_mask):
idx = -(idx+1)
cur_mask = seg_label[idx, :, :]
cur_mask = mmcv.imresize(cur_mask, (w, h))
cur_mask = (cur_mask > 0.5).astype(np.uint8)
if cur_mask.sum() == 0:
continue
cur_cate = cate_label[idx] cur_cate = cate_label[idx]
cur_score = cate_score[idx] cur_score = cate_score[idx]
...@@ -92,7 +77,6 @@ def vis_seg(data, result, img_norm_cfg, data_id, colors, score_thr, save_dir): ...@@ -92,7 +77,6 @@ def vis_seg(data, result, img_norm_cfg, data_id, colors, score_thr, save_dir):
vis_pos = (max(int(center_x) - 10, 0), int(center_y)) vis_pos = (max(int(center_x) - 10, 0), int(center_y))
cv2.putText(seg_show, label_text, vis_pos, cv2.putText(seg_show, label_text, vis_pos,
cv2.FONT_HERSHEY_COMPLEX, 0.3, (255, 255, 255)) # green cv2.FONT_HERSHEY_COMPLEX, 0.3, (255, 255, 255)) # green
mmcv.imwrite(seg_show, '{}/{}.jpg'.format(save_dir, data_id)) mmcv.imwrite(seg_show, '{}/{}.jpg'.format(save_dir, data_id))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment