from abc import ABCMeta, abstractmethod

import mmcv
import numpy as np
import pycocotools.mask as maskUtils
import torch.nn as nn

from mmdet.core import auto_fp16, get_classes, tensor2imgs
from mmdet.utils import print_log


class BaseDetector(nn.Module, metaclass=ABCMeta):
    """Base class for detectors."""

    def __init__(self):
        super(BaseDetector, self).__init__()
        self.fp16_enabled = False

    @property
    def with_neck(self):
        return hasattr(self, 'neck') and self.neck is not None

    @property
    def with_shared_head(self):
        return hasattr(self, 'shared_head') and self.shared_head is not None

    @property
    def with_bbox(self):
        return hasattr(self, 'bbox_head') and self.bbox_head is not None

    @property
    def with_mask(self):
        return hasattr(self, 'mask_head') and self.mask_head is not None

    @abstractmethod
    def extract_feat(self, imgs):
        pass

    def extract_feats(self, imgs):
        assert isinstance(imgs, list)
        for img in imgs:
            yield self.extract_feat(img)

    @abstractmethod
    def forward_train(self, imgs, img_metas, **kwargs):
        """
        Args:
            imgs (list[Tensor]): List of tensors of shape (1, C, H, W).
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): List of image info dicts where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also
                contain 'filename', 'ori_shape', 'pad_shape', and
                'img_norm_cfg'. For details on the values of these keys, see
                `mmdet/datasets/pipelines/formatting.py:Collect`.
            **kwargs: Specific to the concrete implementation.
        """
        pass

    async def async_simple_test(self, img, img_meta, **kwargs):
        raise NotImplementedError

    @abstractmethod
    def simple_test(self, img, img_meta, **kwargs):
        pass

    @abstractmethod
    def aug_test(self, imgs, img_metas, **kwargs):
        pass

    def init_weights(self, pretrained=None):
        if pretrained is not None:
            print_log('load model from: {}'.format(pretrained), logger='root')

    async def aforward_test(self, *, img, img_meta, **kwargs):
        for var, name in [(img, 'img'), (img_meta, 'img_meta')]:
            if not isinstance(var, list):
                raise TypeError('{} must be a list, but got {}'.format(
                    name, type(var)))

        num_augs = len(img)
        if num_augs != len(img_meta):
            raise ValueError(
                'num of augmentations ({}) != num of image meta ({})'.format(
                    len(img), len(img_meta)))
        # TODO: remove the restriction of imgs_per_gpu == 1 when prepared
        imgs_per_gpu = img[0].size(0)
        assert imgs_per_gpu == 1

        if num_augs == 1:
            return await self.async_simple_test(img[0], img_meta[0], **kwargs)
        else:
            raise NotImplementedError

    def forward_test(self, imgs, img_metas, **kwargs):
        """
        Args:
            imgs (List[Tensor]): The outer list indicates test-time
                augmentations; each inner Tensor has shape (N, C, H, W) and
                contains all images in the batch.
            img_metas (List[List[dict]]): The outer list indicates test-time
                augmentations (multiscale, flip, etc.) and the inner list
                indicates the images in a batch.
        """
        for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:
            if not isinstance(var, list):
                raise TypeError('{} must be a list, but got {}'.format(
                    name, type(var)))

        num_augs = len(imgs)
        if num_augs != len(img_metas):
            raise ValueError(
                'num of augmentations ({}) != num of image meta ({})'.format(
                    len(imgs), len(img_metas)))
        # TODO: remove the restriction of imgs_per_gpu == 1 when prepared
        imgs_per_gpu = imgs[0].size(0)
        assert imgs_per_gpu == 1

        if num_augs == 1:
            return self.simple_test(imgs[0], img_metas[0], **kwargs)
        else:
            return self.aug_test(imgs, img_metas, **kwargs)

    @auto_fp16(apply_to=('img', ))
    def forward(self, img, img_meta, return_loss=True, **kwargs):
        """Calls either forward_train or forward_test depending on whether
        return_loss=True.

        Note that this setting changes the expected inputs. When
        `return_loss=True`, img and img_meta are single-nested (i.e. Tensor
        and List[dict]), and when `return_loss=False`, img and img_meta
        should be double-nested (i.e. List[Tensor], List[List[dict]]), with
        the outer list indicating test-time augmentations.
        """
        if return_loss:
            return self.forward_train(img, img_meta, **kwargs)
        else:
            return self.forward_test(img, img_meta, **kwargs)

    def show_result(self, data, result, dataset=None, score_thr=0.3):
        if isinstance(result, tuple):
            bbox_result, segm_result = result
        else:
            bbox_result, segm_result = result, None

        img_tensor = data['img'][0]
        img_metas = data['img_meta'][0].data[0]
        imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
        assert len(imgs) == len(img_metas)

        if dataset is None:
            class_names = self.CLASSES
        elif isinstance(dataset, str):
            class_names = get_classes(dataset)
        elif isinstance(dataset, (list, tuple)):
            class_names = dataset
        else:
            raise TypeError(
                'dataset must be a valid dataset name or a sequence'
                ' of class names, not {}'.format(type(dataset)))

        for img, img_meta in zip(imgs, img_metas):
            h, w, _ = img_meta['img_shape']
            img_show = img[:h, :w, :]

            bboxes = np.vstack(bbox_result)
            # draw segmentation masks
            if segm_result is not None:
                segms = mmcv.concat_list(segm_result)
                inds = np.where(bboxes[:, -1] > score_thr)[0]
                for i in inds:
                    color_mask = np.random.randint(
                        0, 256, (1, 3), dtype=np.uint8)
                    # use the builtin bool; np.bool is deprecated in numpy
                    mask = maskUtils.decode(segms[i]).astype(bool)
                    img_show[mask] = img_show[mask] * 0.5 + color_mask * 0.5
            # draw bounding boxes
            labels = [
                np.full(bbox.shape[0], i, dtype=np.int32)
                for i, bbox in enumerate(bbox_result)
            ]
            labels = np.concatenate(labels)
            mmcv.imshow_det_bboxes(
                img_show,
                bboxes,
                labels,
                class_names=class_names,
                score_thr=score_thr)
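

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of mmdet: a minimal concrete subclass showing
# the four hooks BaseDetector leaves abstract (extract_feat, forward_train,
# simple_test, aug_test). `_ToyDetector` and its conv "backbone" / linear
# "head" are hypothetical stand-ins; real detectors build backbone/neck/head
# modules from a config.
# ---------------------------------------------------------------------------
class _ToyDetector(BaseDetector):

    def __init__(self):
        super(_ToyDetector, self).__init__()
        self.backbone = nn.Conv2d(3, 8, 3, padding=1)  # toy feature extractor
        self.cls_score = nn.Linear(8, 2)  # toy 2-class "head"

    def extract_feat(self, img):
        # (N, 3, H, W) -> (N, 8, H, W)
        return self.backbone(img)

    def forward_train(self, img, img_meta, gt_labels=None, **kwargs):
        # Training mode must return a dict of named losses.
        feat = self.extract_feat(img).mean(dim=(2, 3))  # global average pool
        logits = self.cls_score(feat)
        return dict(loss_cls=nn.functional.cross_entropy(logits, gt_labels))

    def simple_test(self, img, img_meta, **kwargs):
        # Single-image, single-augmentation test path.
        feat = self.extract_feat(img).mean(dim=(2, 3))
        return self.cls_score(feat).argmax(dim=1)

    def aug_test(self, imgs, img_metas, **kwargs):
        # Multi-augmentation testing is not needed for this sketch.
        raise NotImplementedError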
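

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of mmdet: how the `return_loss` flag changes
# the expected input nesting, per the `forward` docstring above. The meta
# dict carries only the keys this file documents; the shapes and values are
# made up for the demo, and `_demo_forward` is a hypothetical helper.
# ---------------------------------------------------------------------------
def _demo_forward(detector):
    import torch
    img = torch.zeros(1, 3, 32, 32)
    meta = dict(img_shape=(32, 32, 3), scale_factor=1.0, flip=False)

    # Training: single-nested inputs (Tensor, List[dict]) -> dict of losses.
    losses = detector(
        img, [meta], return_loss=True, gt_labels=torch.zeros(1).long())

    # Testing: double-nested inputs (List[Tensor], List[List[dict]]); the
    # outer list enumerates test-time augmentations (a single one here).
    results = detector([img], [[meta]], return_loss=False)
    return losses, results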