Commit d6a724fb authored by Kai Chen's avatar Kai Chen
Browse files

update inference api

parent 2507eb6f
from .env import init_dist, get_root_logger, set_random_seed from .env import init_dist, get_root_logger, set_random_seed
from .train import train_detector from .train import train_detector
from .inference import inference_detector from .inference import inference_detector, show_result
__all__ = [ __all__ = [
'init_dist', 'get_root_logger', 'set_random_seed', 'train_detector', 'init_dist', 'get_root_logger', 'set_random_seed', 'train_detector',
'inference_detector' 'inference_detector', 'show_result'
] ]
...@@ -23,19 +23,29 @@ def _prepare_data(img, img_transform, cfg, device): ...@@ -23,19 +23,29 @@ def _prepare_data(img, img_transform, cfg, device):
return dict(img=[img], img_meta=[img_meta]) return dict(img=[img], img_meta=[img_meta])
def inference_detector(model, imgs, cfg, device='cuda:0'): def _inference_single(model, img, img_transform, cfg, device):
img = mmcv.imread(img)
data = _prepare_data(img, img_transform, cfg, device)
with torch.no_grad():
result = model(return_loss=False, rescale=True, **data)
return result
def _inference_generator(model, imgs, img_transform, cfg, device):
for img in imgs:
yield _inference_single(model, img, img_transform, cfg, device)
imgs = imgs if isinstance(imgs, list) else [imgs]
def inference_detector(model, imgs, cfg, device='cuda:0'):
img_transform = ImageTransform( img_transform = ImageTransform(
size_divisor=cfg.data.test.size_divisor, **cfg.img_norm_cfg) size_divisor=cfg.data.test.size_divisor, **cfg.img_norm_cfg)
model = model.to(device) model = model.to(device)
model.eval() model.eval()
for img in imgs:
img = mmcv.imread(img) if not isinstance(imgs, list):
data = _prepare_data(img, img_transform, cfg, device) return _inference_single(model, imgs, img_transform, cfg, device)
with torch.no_grad(): else:
result = model(return_loss=False, rescale=True, **data) return _inference_generator(model, imgs, img_transform, cfg, device)
yield result
def show_result(img, result, dataset='coco', score_thr=0.3): def show_result(img, result, dataset='coco', score_thr=0.3):
...@@ -46,6 +56,7 @@ def show_result(img, result, dataset='coco', score_thr=0.3): ...@@ -46,6 +56,7 @@ def show_result(img, result, dataset='coco', score_thr=0.3):
] ]
labels = np.concatenate(labels) labels = np.concatenate(labels)
bboxes = np.vstack(result) bboxes = np.vstack(result)
img = mmcv.imread(img)
mmcv.imshow_det_bboxes( mmcv.imshow_det_bboxes(
img.copy(), img.copy(),
bboxes, bboxes,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment