Commit 830effcd authored by Kai Chen's avatar Kai Chen
Browse files

add default result visualization for base detector

parent 65c3ebca
......@@ -95,7 +95,7 @@ def get_classes(dataset):
if mmcv.is_str(dataset):
if dataset in alias2name:
labels = eval(alias2name[dataset] + '_labels()')
labels = eval(alias2name[dataset] + '_classes()')
else:
raise ValueError('Unrecognized dataset: {}'.format(dataset))
else:
......
import logging
from abc import ABCMeta, abstractmethod
import mmcv
import numpy as np
import torch
import torch.nn as nn
from mmdet.core import tensor2imgs, get_classes
class BaseDetector(nn.Module):
"""Base class for detectors"""
......@@ -66,3 +70,38 @@ class BaseDetector(nn.Module):
return self.forward_train(img, img_meta, **kwargs)
else:
return self.forward_test(img, img_meta, **kwargs)
def show_result(self,
data,
result,
img_norm_cfg,
dataset='coco',
score_thr=0.3):
img_tensor = data['img'][0]
img_metas = data['img_meta'][0].data[0]
imgs = tensor2imgs(img_tensor, **img_norm_cfg)
assert len(imgs) == len(img_metas)
if isinstance(dataset, str):
class_names = get_classes(dataset)
elif isinstance(dataset, list):
class_names = dataset
else:
raise TypeError('dataset must be a valid dataset name or a list'
' of class names, not {}'.format(type(dataset)))
for img, img_meta in zip(imgs, img_metas):
h, w, _ = img_meta['img_shape']
img_show = img[:h, :w, :]
labels = [
np.full(bbox.shape[0], i, dtype=np.int32)
for i, bbox in enumerate(result)
]
labels = np.concatenate(labels)
bboxes = np.vstack(result)
mmcv.imshow_det_bboxes(
img_show,
bboxes,
labels,
class_names=class_names,
score_thr=score_thr)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment