".github/vscode:/vscode.git/clone" did not exist on "2b4ca6cf36c0dc31fdae7046433f4341f171026f"
Unverified commit 7facc34f, authored by Kai Chen and committed by GitHub
Browse files

Save class names in checkpoints and update the high-level inference APIs (#645)

* update the high-level inference api

* save classes in meta data and use it for visualization
parent 8d38fd8c
......@@ -62,28 +62,23 @@ python tools/test.py configs/mask_rcnn_r50_fpn_1x.py \
Here is an example of building the model and test given images.
```python
import mmcv
from mmcv.runner import load_checkpoint
from mmdet.models import build_detector
from mmdet.apis import inference_detector, show_result
from mmdet.apis import init_detector, inference_detector, show_result
cfg = mmcv.Config.fromfile('configs/faster_rcnn_r50_fpn_1x.py')
cfg.model.pretrained = None
config_file = 'configs/faster_rcnn_r50_fpn_1x.py'
checkpoint_file = 'checkpoints/faster_rcnn_r50_fpn_1x_20181010-3d1b3351.pth'
# construct the model and load checkpoint
model = build_detector(cfg.model, test_cfg=cfg.test_cfg)
_ = load_checkpoint(model, 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/faster_rcnn_r50_fpn_1x_20181010-3d1b3351.pth')
# build the model from a config file and a checkpoint file
model = init_detector(config_file, checkpoint_file)
# test a single image
img = mmcv.imread('test.jpg')
result = inference_detector(model, img, cfg)
show_result(img, result)
# test a single image and show the results
img = 'test.jpg' # or img = mmcv.imread(img), which will only load it once
result = inference_detector(model, img)
show_result(img, result, model.CLASSES)
# test a list of images
# test a list of images and write the results to image files
imgs = ['test1.jpg', 'test2.jpg']
for i, result in enumerate(inference_detector(model, imgs, cfg, device='cuda:0')):
print(i, imgs[i])
show_result(imgs[i], result)
for i, result in enumerate(inference_detector(model, imgs, device='cuda:0')):
show_result(imgs[i], result, model.CLASSES, out_file='result_{}.jpg'.format(i))
```
......
from .env import init_dist, get_root_logger, set_random_seed
from .train import train_detector
from .inference import inference_detector, show_result
from .inference import init_detector, inference_detector, show_result
__all__ = [
'init_dist', 'get_root_logger', 'set_random_seed', 'train_detector',
'inference_detector', 'show_result'
'init_detector', 'inference_detector', 'show_result'
]
import warnings
import mmcv
import numpy as np
import pycocotools.mask as maskUtils
import torch
from mmcv.runner import load_checkpoint
from mmdet.core import get_classes
from mmdet.datasets import to_tensor
from mmdet.datasets.transforms import ImageTransform
from mmdet.models import build_detector
def init_detector(config, checkpoint=None, device='cuda:0'):
    """Initialize a detector from config file.

    Args:
        config (str or :obj:`mmcv.Config`): Config file path or the config
            object.
        checkpoint (str, optional): Checkpoint path. If left as None, the model
            will not load any weights.
        device (str): Device the constructed model is moved to.

    Returns:
        nn.Module: The constructed detector.

    Raises:
        TypeError: If ``config`` is neither a str nor an ``mmcv.Config``.
    """
    if isinstance(config, str):
        config = mmcv.Config.fromfile(config)
    elif not isinstance(config, mmcv.Config):
        raise TypeError('config must be a filename or Config object, '
                        'but got {}'.format(type(config)))
    # weights come from ``checkpoint`` (if any), never from ``pretrained``
    config.model.pretrained = None
    model = build_detector(config.model, test_cfg=config.test_cfg)
    if checkpoint is not None:
        checkpoint = load_checkpoint(model, checkpoint)
        # tolerate a checkpoint dict without a 'meta' entry instead of
        # raising KeyError
        meta = checkpoint.get('meta', {})
        # the key tested here must be the same key that is read below
        # (class names are stored under the lower-case 'classes' key)
        if 'classes' in meta:
            model.CLASSES = meta['classes']
        else:
            warnings.warn('Class names are not saved in the checkpoint\'s '
                          'meta data, use COCO classes by default.')
            model.CLASSES = get_classes('coco')
    model.cfg = config  # save the config in the model for convenience
    model.to(device)
    model.eval()
    return model
def inference_detector(model, imgs):
    """Inference image(s) with the detector.

    Args:
        model (nn.Module): The loaded detector. Its ``cfg`` attribute is read
            to build the image transform.
        imgs (str/ndarray or list[str/ndarray]): Either image files or loaded
            images.

    Returns:
        If imgs is a list, a generator yielding the detection result of each
        image in order is returned; otherwise (a single str/ndarray) the
        detection result of that image is returned directly.
    """
    cfg = model.cfg
    img_transform = ImageTransform(
        size_divisor=cfg.data.test.size_divisor, **cfg.img_norm_cfg)

    # run on whichever device the model's parameters already live on
    device = next(model.parameters()).device
    if isinstance(imgs, list):
        return _inference_generator(model, imgs, img_transform, device)
    return _inference_single(model, imgs, img_transform, device)
def _prepare_data(img, img_transform, cfg, device):
......@@ -26,34 +86,34 @@ def _prepare_data(img, img_transform, cfg, device):
return dict(img=[img], img_meta=[img_meta])
def _inference_single(model, img, img_transform, cfg, device):
def _inference_single(model, img, img_transform, device):
img = mmcv.imread(img)
data = _prepare_data(img, img_transform, cfg, device)
data = _prepare_data(img, img_transform, model.cfg, device)
with torch.no_grad():
result = model(return_loss=False, rescale=True, **data)
return result
def _inference_generator(model, imgs, img_transform, cfg, device):
def _inference_generator(model, imgs, img_transform, device):
for img in imgs:
yield _inference_single(model, img, img_transform, cfg, device)
yield _inference_single(model, img, img_transform, device)
def inference_detector(model, imgs, cfg, device='cuda:0'):
img_transform = ImageTransform(
size_divisor=cfg.data.test.size_divisor, **cfg.img_norm_cfg)
model = model.to(device)
model.eval()
if not isinstance(imgs, list):
return _inference_single(model, imgs, img_transform, cfg, device)
else:
return _inference_generator(model, imgs, img_transform, cfg, device)
# TODO: merge this method with the one in BaseDetector
def show_result(img, result, class_names, score_thr=0.3, out_file=None):
"""Visualize the detection results on the image.
def show_result(img, result, dataset='coco', score_thr=0.3, out_file=None):
Args:
img (str or np.ndarray): Image filename or loaded image.
result (tuple[list] or list): The detection result, can be either
(bbox, segm) or just bbox.
class_names (list[str] or tuple[str]): A list of class names.
score_thr (float): The threshold to visualize the bboxes and masks.
out_file (str, optional): If specified, the visualization result will
be written to the out file instead of shown in a window.
"""
assert isinstance(class_names, (tuple, list))
img = mmcv.imread(img)
class_names = get_classes(dataset)
if isinstance(result, tuple):
bbox_result, segm_result = result
else:
......
......@@ -89,7 +89,7 @@ class BaseDetector(nn.Module):
data,
result,
img_norm_cfg,
dataset='coco',
dataset=None,
score_thr=0.3):
if isinstance(result, tuple):
bbox_result, segm_result = result
......@@ -101,9 +101,11 @@ class BaseDetector(nn.Module):
imgs = tensor2imgs(img_tensor, **img_norm_cfg)
assert len(imgs) == len(img_metas)
if isinstance(dataset, str):
if dataset is None:
class_names = self.CLASSES
elif isinstance(dataset, str):
class_names = get_classes(dataset)
elif isinstance(dataset, (list, tuple)) or dataset is None:
elif isinstance(dataset, (list, tuple)):
class_names = dataset
else:
raise TypeError(
......
......@@ -26,10 +26,7 @@ def single_gpu_test(model, data_loader, show=False):
results.append(result)
if show:
model.module.show_result(data,
result,
dataset.img_norm_cfg,
dataset=dataset.CLASSES)
model.module.show_result(data, result, dataset.img_norm_cfg)
batch_size = data['img'][0].size(0)
for _ in range(batch_size):
......
......@@ -52,10 +52,6 @@ def main():
if args.resume_from is not None:
cfg.resume_from = args.resume_from
cfg.gpus = args.gpus
if cfg.checkpoint_config is not None:
# save mmdet version in checkpoints as meta data
cfg.checkpoint_config.meta = dict(
mmdet_version=__version__, config=cfg.text)
# init distributed env first, since logger depends on the dist info.
if args.launcher == 'none':
......@@ -77,6 +73,14 @@ def main():
cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
train_dataset = get_dataset(cfg.data.train)
if cfg.checkpoint_config is not None:
# save mmdet version, config file content and class names in
# checkpoints as meta data
cfg.checkpoint_config.meta = dict(
mmdet_version=__version__, config=cfg.text,
classes=train_dataset.CLASSES)
# add an attribute for visualization convenience
model.CLASSES = train_dataset.CLASSES
train_detector(
model,
train_dataset,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment