Unverified commit 7facc34f, authored by Kai Chen and committed by GitHub

Save class names in checkpoints and update the high-level inference APIs (#645)

* update the high-level inference api

* save classes in meta data and use it for visualization
parent 8d38fd8c
@@ -62,28 +62,23 @@ python tools/test.py configs/mask_rcnn_r50_fpn_1x.py \
 Here is an example of building the model and test given images.
 
 ```python
-import mmcv
-from mmcv.runner import load_checkpoint
-from mmdet.models import build_detector
-from mmdet.apis import inference_detector, show_result
+from mmdet.apis import init_detector, inference_detector, show_result
 
-cfg = mmcv.Config.fromfile('configs/faster_rcnn_r50_fpn_1x.py')
-cfg.model.pretrained = None
+config_file = 'configs/faster_rcnn_r50_fpn_1x.py'
+checkpoint_file = 'checkpoints/faster_rcnn_r50_fpn_1x_20181010-3d1b3351.pth'
 
-# construct the model and load checkpoint
-model = build_detector(cfg.model, test_cfg=cfg.test_cfg)
-_ = load_checkpoint(model, 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/faster_rcnn_r50_fpn_1x_20181010-3d1b3351.pth')
+# build the model from a config file and a checkpoint file
+model = init_detector(config_file, checkpoint_file)
 
-# test a single image
-img = mmcv.imread('test.jpg')
-result = inference_detector(model, img, cfg)
-show_result(img, result)
+# test a single image and show the results
+img = 'test.jpg'  # or img = mmcv.imread(img), which will only load it once
+result = inference_detector(model, img)
+show_result(img, result, model.CLASSES)
 
-# test a list of images
+# test a list of images and write the results to image files
 imgs = ['test1.jpg', 'test2.jpg']
-for i, result in enumerate(inference_detector(model, imgs, cfg, device='cuda:0')):
-    print(i, imgs[i])
-    show_result(imgs[i], result)
+for i, result in enumerate(inference_detector(model, imgs, device='cuda:0')):
+    show_result(imgs[i], result, model.CLASSES, out_file='result_{}.jpg'.format(i))
 ```
...

 from .env import init_dist, get_root_logger, set_random_seed
 from .train import train_detector
-from .inference import inference_detector, show_result
+from .inference import init_detector, inference_detector, show_result
 
 __all__ = [
     'init_dist', 'get_root_logger', 'set_random_seed', 'train_detector',
-    'inference_detector', 'show_result'
+    'init_detector', 'inference_detector', 'show_result'
 ]
+import warnings
+
 import mmcv
 import numpy as np
 import pycocotools.mask as maskUtils
 import torch
+from mmcv.runner import load_checkpoint
 
 from mmdet.core import get_classes
 from mmdet.datasets import to_tensor
 from mmdet.datasets.transforms import ImageTransform
+from mmdet.models import build_detector
+
+
+def init_detector(config, checkpoint=None, device='cuda:0'):
+    """Initialize a detector from config file.
+
+    Args:
+        config (str or :obj:`mmcv.Config`): Config file path or the config
+            object.
+        checkpoint (str, optional): Checkpoint path. If left as None, the
+            model will not load any weights.
+        device (str, optional): The device to put the model on,
+            e.g. 'cuda:0'.
+
+    Returns:
+        nn.Module: The constructed detector.
+    """
+    if isinstance(config, str):
+        config = mmcv.Config.fromfile(config)
+    elif not isinstance(config, mmcv.Config):
+        raise TypeError('config must be a filename or Config object, '
+                        'but got {}'.format(type(config)))
+    config.model.pretrained = None
+    model = build_detector(config.model, test_cfg=config.test_cfg)
+    if checkpoint is not None:
+        checkpoint = load_checkpoint(model, checkpoint)
+        if 'classes' in checkpoint['meta']:
+            model.CLASSES = checkpoint['meta']['classes']
+        else:
+            warnings.warn('Class names are not saved in the checkpoint\'s '
+                          'meta data, use COCO classes by default.')
+            model.CLASSES = get_classes('coco')
+    model.cfg = config  # save the config in the model for convenience
+    model.to(device)
+    model.eval()
+    return model
+
+
+def inference_detector(model, imgs):
+    """Inference image(s) with the detector.
+
+    Args:
+        model (nn.Module): The loaded detector.
+        imgs (str/ndarray or list[str/ndarray]): Either image files or loaded
+            images.
+
+    Returns:
+        If imgs is a list, a generator of results will be returned, otherwise
+        the detection result of the single image is returned directly.
+    """
+    cfg = model.cfg
+    img_transform = ImageTransform(
+        size_divisor=cfg.data.test.size_divisor, **cfg.img_norm_cfg)
+
+    device = next(model.parameters()).device  # model device
+    if not isinstance(imgs, list):
+        return _inference_single(model, imgs, img_transform, device)
+    else:
+        return _inference_generator(model, imgs, img_transform, device)
+
+
 def _prepare_data(img, img_transform, cfg, device):
@@ -26,34 +86,34 @@ def _prepare_data(img, img_transform, cfg, device):
     return dict(img=[img], img_meta=[img_meta])
 
 
-def _inference_single(model, img, img_transform, cfg, device):
+def _inference_single(model, img, img_transform, device):
     img = mmcv.imread(img)
-    data = _prepare_data(img, img_transform, cfg, device)
+    data = _prepare_data(img, img_transform, model.cfg, device)
     with torch.no_grad():
         result = model(return_loss=False, rescale=True, **data)
     return result
 
 
-def _inference_generator(model, imgs, img_transform, cfg, device):
+def _inference_generator(model, imgs, img_transform, device):
     for img in imgs:
-        yield _inference_single(model, img, img_transform, cfg, device)
+        yield _inference_single(model, img, img_transform, device)
 
 
-def inference_detector(model, imgs, cfg, device='cuda:0'):
-    img_transform = ImageTransform(
-        size_divisor=cfg.data.test.size_divisor, **cfg.img_norm_cfg)
-
-    model = model.to(device)
-    model.eval()
-
-    if not isinstance(imgs, list):
-        return _inference_single(model, imgs, img_transform, cfg, device)
-    else:
-        return _inference_generator(model, imgs, img_transform, cfg, device)
-
-
-def show_result(img, result, dataset='coco', score_thr=0.3, out_file=None):
+# TODO: merge this method with the one in BaseDetector
+def show_result(img, result, class_names, score_thr=0.3, out_file=None):
+    """Visualize the detection results on the image.
+
+    Args:
+        img (str or np.ndarray): Image filename or loaded image.
+        result (tuple[list] or list): The detection result, can be either
+            (bbox, segm) or just bbox.
+        class_names (list[str] or tuple[str]): A list of class names.
+        score_thr (float): The threshold to visualize the bboxes and masks.
+        out_file (str, optional): If specified, the visualization result will
+            be written to the out file instead of shown in a window.
+    """
+    assert isinstance(class_names, (tuple, list))
     img = mmcv.imread(img)
-    class_names = get_classes(dataset)
     if isinstance(result, tuple):
         bbox_result, segm_result = result
     else:
...
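The new `inference_detector` keeps the model, config and device together, so callers no longer pass `cfg` or `device` per call. A minimal usage sketch under the assumption that a config and checkpoint are available locally (all file paths below are placeholders):

```python
from mmdet.apis import init_detector, inference_detector, show_result

# Placeholder paths: substitute a real config and checkpoint file.
config_file = 'configs/faster_rcnn_r50_fpn_1x.py'
checkpoint_file = 'checkpoints/faster_rcnn_r50_fpn_1x.pth'

# The device is now bound to the model once, at init time.
model = init_detector(config_file, checkpoint_file, device='cuda:0')

# A single image (path or loaded ndarray) returns the result directly.
result = inference_detector(model, 'demo.jpg')
show_result('demo.jpg', result, model.CLASSES, score_thr=0.5)

# A list of images returns a generator; results are computed lazily
# as the loop consumes them.
img_list = ['demo1.jpg', 'demo2.jpg']
for img, result in zip(img_list, inference_detector(model, img_list)):
    show_result(img, result, model.CLASSES,
                out_file=img.replace('.jpg', '_det.jpg'))
```

If the checkpoint's meta data carries no class names, `init_detector` warns and falls back to the COCO classes, so `model.CLASSES` is always populated before `show_result` is called.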
@@ -89,7 +89,7 @@ class BaseDetector(nn.Module):
                     data,
                     result,
                     img_norm_cfg,
-                    dataset='coco',
+                    dataset=None,
                     score_thr=0.3):
         if isinstance(result, tuple):
             bbox_result, segm_result = result
@@ -101,9 +101,11 @@ class BaseDetector(nn.Module):
         imgs = tensor2imgs(img_tensor, **img_norm_cfg)
         assert len(imgs) == len(img_metas)
 
-        if isinstance(dataset, str):
+        if dataset is None:
+            class_names = self.CLASSES
+        elif isinstance(dataset, str):
             class_names = get_classes(dataset)
-        elif isinstance(dataset, (list, tuple)) or dataset is None:
+        elif isinstance(dataset, (list, tuple)):
             class_names = dataset
         else:
             raise TypeError(
...
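`BaseDetector.show_result` now resolves class names with a clear precedence: `None` falls back to `self.CLASSES` (restored from the checkpoint or set at training time), a string is treated as a dataset name, and an explicit list or tuple is used as-is. A standalone sketch of that precedence; the helper name `resolve_class_names` is hypothetical and not part of mmdetection:

```python
from mmdet.core import get_classes


def resolve_class_names(dataset, model_classes):
    """Hypothetical helper mirroring the precedence used in
    BaseDetector.show_result after this change."""
    if dataset is None:
        # No dataset given: use the classes attached to the model.
        return model_classes
    elif isinstance(dataset, str):
        # A dataset name such as 'coco' or 'voc' is looked up.
        return get_classes(dataset)
    elif isinstance(dataset, (list, tuple)):
        # An explicit sequence of class names is used as-is.
        return dataset
    else:
        raise TypeError('dataset must be None, a str or a sequence, '
                        'but got {}'.format(type(dataset)))


print(resolve_class_names(None, ('cat', 'dog')))        # ('cat', 'dog')
print(resolve_class_names('voc', ('cat', 'dog'))[:3])   # first 3 VOC classes
print(resolve_class_names(['a', 'b'], ('cat', 'dog')))  # ['a', 'b']
```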
@@ -26,10 +26,7 @@ def single_gpu_test(model, data_loader, show=False):
         results.append(result)
 
         if show:
-            model.module.show_result(data,
-                                     result,
-                                     dataset.img_norm_cfg,
-                                     dataset=dataset.CLASSES)
+            model.module.show_result(data, result, dataset.img_norm_cfg)
 
         batch_size = data['img'][0].size(0)
         for _ in range(batch_size):
...
@@ -52,10 +52,6 @@ def main():
     if args.resume_from is not None:
         cfg.resume_from = args.resume_from
     cfg.gpus = args.gpus
-    if cfg.checkpoint_config is not None:
-        # save mmdet version in checkpoints as meta data
-        cfg.checkpoint_config.meta = dict(
-            mmdet_version=__version__, config=cfg.text)
 
     # init distributed env first, since logger depends on the dist info.
     if args.launcher == 'none':
@@ -77,6 +73,14 @@ def main():
         cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
 
     train_dataset = get_dataset(cfg.data.train)
+    if cfg.checkpoint_config is not None:
+        # save mmdet version, config file content and class names in
+        # checkpoints as meta data
+        cfg.checkpoint_config.meta = dict(
+            mmdet_version=__version__, config=cfg.text,
+            classes=train_dataset.CLASSES)
+    # add an attribute for visualization convenience
+    model.CLASSES = train_dataset.CLASSES
     train_detector(
         model,
         train_dataset,
...
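Since `cfg.checkpoint_config.meta` is passed on to the checkpoint hook, the class names travel inside the saved checkpoint file itself. Assuming the checkpoint is stored in mmcv's usual format (a dict with `meta` and `state_dict` keys), the meta data can be inspected directly; the path below is a placeholder:

```python
import torch

# Placeholder path: any checkpoint produced by tools/train.py after this change.
ckpt = torch.load('work_dirs/faster_rcnn_r50_fpn_1x/latest.pth', map_location='cpu')

meta = ckpt.get('meta', {})
print(meta.get('mmdet_version'))  # mmdet version the checkpoint was trained with
print(meta.get('classes'))        # class names of the training dataset
```

This is the field `init_detector` reads back to populate `model.CLASSES`, falling back to the COCO class names (with a warning) for older checkpoints that lack it.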