from abc import ABCMeta, abstractmethod

import torch
import torch.nn as nn


class BaseDetector(nn.Module, metaclass=ABCMeta):
    """Base class for detectors.

    Subclasses must implement ``init_weights``, ``extract_feat``,
    ``forward_train``, ``simple_test`` and ``aug_test``. ``forward`` routes a
    call either to ``forward_train`` or to ``forward_test`` depending on
    ``return_loss``.
    """

    # NOTE: the metaclass is given in the class statement (Python 3 syntax).
    # A ``__metaclass__ = ABCMeta`` class attribute is a Python 2 idiom that
    # Python 3 ignores, which left the @abstractmethod markers unenforced.

    def __init__(self):
        super().__init__()

    @abstractmethod
    def init_weights(self):
        """Weight-initialization hook; must be implemented by subclasses."""

    @abstractmethod
    def extract_feat(self, imgs):
        """Return features for a single ``imgs`` input.

        Called by ``extract_feats`` with either a tensor or each element of a
        list of tensors.
        """

    def extract_feats(self, imgs):
        """Apply ``extract_feat`` to a tensor or to each item of a list.

        Args:
            imgs (torch.Tensor | list): a single tensor, or a list whose
                items are each passed to ``extract_feat``.

        Returns:
            For a tensor: the value of ``self.extract_feat(imgs)``.
            For a list: a lazy generator yielding ``self.extract_feat(img)``
            for each ``img`` in the list.

        Raises:
            TypeError: if ``imgs`` is neither a tensor nor a list.
        """
        # This must not be a generator function itself: a ``return value``
        # inside a generator is discarded, so the tensor branch would silently
        # produce an empty iterator instead of the features.
        if isinstance(imgs, torch.Tensor):
            return self.extract_feat(imgs)
        if isinstance(imgs, list):
            return (self.extract_feat(img) for img in imgs)
        raise TypeError(
            'imgs must be a Tensor or a list, but got {}'.format(type(imgs)))

    @abstractmethod
    def forward_train(self, imgs, img_metas, **kwargs):
        """Training-mode forward pass, dispatched from ``forward``."""

    @abstractmethod
    def simple_test(self, img, img_meta, **kwargs):
        """Test without augmentation (a single entry of the test lists)."""

    @abstractmethod
    def aug_test(self, imgs, img_metas, **kwargs):
        """Test with several augmentations (the full test lists)."""

    def forward_test(self, imgs, img_metas, **kwargs):
        """Validate test inputs and dispatch to simple or augmented testing.

        Args:
            imgs (list): one entry per test-time augmentation. Each entry's
                first dimension (``.size(0)``) is the number of images and
                must currently be 1.
            img_metas (list): meta information, one entry per augmentation;
                must have the same length as ``imgs``.

        Returns:
            ``self.simple_test(imgs[0], img_metas[0], **kwargs)`` when there
            is exactly one augmentation, otherwise
            ``self.aug_test(imgs, img_metas, **kwargs)``.

        Raises:
            TypeError: if ``imgs`` or ``img_metas`` is not a list.
            ValueError: if the lists are empty or their lengths differ.
            AssertionError: if more than one image per GPU is given.
        """
        for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:
            if not isinstance(var, list):
                raise TypeError('{} must be a list, but got {}'.format(
                    name, type(var)))

        num_augs = len(imgs)
        if num_augs != len(img_metas):
            raise ValueError(
                'num of augmentations ({}) != num of image meta ({})'.format(
                    len(imgs), len(img_metas)))
        # Guard explicitly: otherwise ``imgs[0]`` below fails with an
        # uninformative IndexError.
        if num_augs == 0:
            raise ValueError(
                'imgs must contain at least one augmentation, '
                'but got an empty list')

        # TODO: remove the restriction of imgs_per_gpu == 1 when prepared
        imgs_per_gpu = imgs[0].size(0)
        assert imgs_per_gpu == 1, (
            'only 1 image per GPU is supported in test mode, '
            'but got {}'.format(imgs_per_gpu))

        if num_augs == 1:
            return self.simple_test(imgs[0], img_metas[0], **kwargs)
        return self.aug_test(imgs, img_metas, **kwargs)

    def forward(self, img, img_meta, return_loss=True, **kwargs):
        """Dispatch to ``forward_train`` or ``forward_test``.

        Args:
            img: image input, passed through unchanged.
            img_meta: image meta information, passed through unchanged.
            return_loss (bool): if True, call ``forward_train``; otherwise
                call ``forward_test`` (which expects lists, one entry per
                augmentation).
            **kwargs: forwarded to the selected method.
        """
        if return_loss:
            return self.forward_train(img, img_meta, **kwargs)
        return self.forward_test(img, img_meta, **kwargs)