base.py 6.75 KB
Newer Older
1
import logging
Kai Chen's avatar
Kai Chen committed
2
3
from abc import ABCMeta, abstractmethod

4
5
import mmcv
import numpy as np
Kai Chen's avatar
Kai Chen committed
6
import pycocotools.mask as maskUtils
7
import torch.nn as nn
Kai Chen's avatar
Kai Chen committed
8

9
from mmdet.core import auto_fp16, get_classes, tensor2imgs
10

Kai Chen's avatar
Kai Chen committed
11
12
13
14
15
16
17
18

class BaseDetector(nn.Module):
    """Base class for detectors.

    Subclasses supply feature extraction (``extract_feat``), the training
    pass (``forward_train``) and the test entry points (``simple_test``,
    ``aug_test``, ``async_simple_test``). This class dispatches between them
    in ``forward`` / ``forward_test`` / ``aforward_test`` and can draw
    results with ``show_result``.
    """

    # NOTE: ``__metaclass__`` is the Python 2 spelling; Python 3 ignores this
    # class attribute, so the ``@abstractmethod`` markers below are not
    # enforced at instantiation time. Declaring
    # ``class BaseDetector(nn.Module, metaclass=ABCMeta)`` would enforce them,
    # but would also make any subclass that leaves one of them unimplemented
    # impossible to instantiate, so the attribute is kept unchanged.
    __metaclass__ = ABCMeta

    def __init__(self):
        super(BaseDetector, self).__init__()
        # Mixed-precision flag; presumably toggled externally and consulted
        # by the ``auto_fp16`` decorator on ``forward`` -- TODO confirm.
        self.fp16_enabled = False

    @property
    def with_neck(self):
        """bool: True if the detector has a ``neck`` that is not None."""
        return hasattr(self, 'neck') and self.neck is not None

    @property
    def with_shared_head(self):
        """bool: True if the detector has a non-None ``shared_head``."""
        return hasattr(self, 'shared_head') and self.shared_head is not None

    @property
    def with_bbox(self):
        """bool: True if the detector has a ``bbox_head`` that is not None."""
        return hasattr(self, 'bbox_head') and self.bbox_head is not None

    @property
    def with_mask(self):
        """bool: True if the detector has a ``mask_head`` that is not None."""
        return hasattr(self, 'mask_head') and self.mask_head is not None

    @abstractmethod
    def extract_feat(self, imgs):
        """Extract features from ``imgs``; implemented by subclasses."""
        pass

    def extract_feats(self, imgs):
        """Lazily apply ``extract_feat`` to each element of ``imgs``.

        Args:
            imgs (list): one input per element (presumably one image tensor
                per test-time augmentation -- confirm against callers).

        Returns:
            generator: yields ``self.extract_feat(img)`` for each ``img``.

        Raises:
            AssertionError: if ``imgs`` is not a list. The check runs as soon
                as this method is called, not on first iteration.
        """
        # Validate eagerly: with ``yield`` in the body the assert would only
        # run once the caller started iterating the generator.
        assert isinstance(imgs, list)
        return (self.extract_feat(img) for img in imgs)

    @abstractmethod
    def forward_train(self, imgs, img_metas, **kwargs):
        """Forward pass used during training; implemented by subclasses.

        Args:
            imgs (list[Tensor]): list of tensors of shape (1, C, H, W).
                Typically these should be mean centered and std scaled.

            img_metas (list[dict]): list of image info dict where each dict
                has:
                'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmdet/datasets/pipelines/formatting.py:Collect`.

            **kwargs: specific to concrete implementation
        """
        pass

    @abstractmethod
    async def async_simple_test(self, img, img_meta, **kwargs):
        """Async single-augmentation test; implemented by subclasses."""
        pass

    @abstractmethod
    def simple_test(self, img, img_meta, **kwargs):
        """Test without augmentation; implemented by subclasses."""
        pass

    @abstractmethod
    def aug_test(self, imgs, img_metas, **kwargs):
        """Test with several augmentations; implemented by subclasses."""
        pass

    def init_weights(self, pretrained=None):
        """Log the checkpoint path when ``pretrained`` is given.

        This base implementation only logs the path to the root logger; it
        does not load any weights itself.

        Args:
            pretrained (str, optional): path/identifier of the checkpoint.
        """
        if pretrained is not None:
            logger = logging.getLogger()
            logger.info('load model from: {}'.format(pretrained))

    @staticmethod
    def _check_test_inputs(imgs, img_metas, imgs_name, metas_name):
        """Validate test-time inputs shared by the sync and async paths.

        Args:
            imgs (list): per-augmentation image batches.
            img_metas (list): per-augmentation meta lists.
            imgs_name (str): name of ``imgs`` used in error messages.
            metas_name (str): name of ``img_metas`` used in error messages.

        Returns:
            int: the number of test-time augmentations.

        Raises:
            TypeError: if either argument is not a list.
            ValueError: if their lengths differ or no augmentation is given.
            AssertionError: if the first batch holds more than one image.
        """
        for var, name in [(imgs, imgs_name), (img_metas, metas_name)]:
            if not isinstance(var, list):
                raise TypeError('{} must be a list, but got {}'.format(
                    name, type(var)))

        num_augs = len(imgs)
        if num_augs != len(img_metas):
            raise ValueError(
                'num of augmentations ({}) != num of image meta ({})'.format(
                    len(imgs), len(img_metas)))
        # Fail clearly instead of with an IndexError on ``imgs[0]`` below.
        if num_augs == 0:
            raise ValueError('{} must not be empty'.format(imgs_name))
        # TODO: remove the restriction of imgs_per_gpu == 1 when prepared
        imgs_per_gpu = imgs[0].size(0)
        assert imgs_per_gpu == 1
        return num_augs

    async def aforward_test(self, *, img, img_meta, **kwargs):
        """Async counterpart of ``forward_test``.

        Args:
            img (list[Tensor]): one image batch per test-time augmentation.
            img_meta (list[list[dict]]): one meta list per augmentation.

        Returns:
            Whatever ``async_simple_test`` returns.

        Raises:
            TypeError, ValueError, AssertionError: on invalid inputs (see
                ``_check_test_inputs``).
            NotImplementedError: if more than one augmentation is given.
        """
        num_augs = self._check_test_inputs(img, img_meta, 'img', 'img_meta')
        if num_augs == 1:
            return await self.async_simple_test(img[0], img_meta[0], **kwargs)
        raise NotImplementedError

    def forward_test(self, imgs, img_metas, **kwargs):
        """Validate test inputs and dispatch to simple or augmented testing.

        Args:
            imgs (List[Tensor]): the outer list indicates test-time
                augmentations and inner Tensor should have a shape NxCxHxW,
                which contains all images in the batch.
            img_metas (List[List[dict]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch

        Returns:
            The result of ``simple_test`` when there is exactly one
            augmentation, otherwise the result of ``aug_test``.

        Raises:
            TypeError, ValueError, AssertionError: on invalid inputs (see
                ``_check_test_inputs``).
        """
        num_augs = self._check_test_inputs(imgs, img_metas, 'imgs',
                                           'img_metas')
        if num_augs == 1:
            return self.simple_test(imgs[0], img_metas[0], **kwargs)
        return self.aug_test(imgs, img_metas, **kwargs)

    @auto_fp16(apply_to=('img', ))
    def forward(self, img, img_meta, return_loss=True, **kwargs):
        """
        Calls either forward_train or forward_test depending on whether
        return_loss=True. Note this setting will change the expected inputs.
        When `return_loss=True`, img and img_meta are single-nested (i.e.
        Tensor and List[dict]), and when `return_loss=False`, img and
        img_meta should be double nested (i.e. List[Tensor],
        List[List[dict]]), with the outer list indicating test time
        augmentations.
        """
        if return_loss:
            return self.forward_train(img, img_meta, **kwargs)
        return self.forward_test(img, img_meta, **kwargs)

    def show_result(self, data, result, dataset=None, score_thr=0.3):
        """Draw detection (and optional mask) results on the input image(s).

        Args:
            data (dict): test-time input batch; ``data['img'][0]`` is the
                image tensor and ``data['img_meta'][0].data[0]`` the list of
                per-image meta dicts (each needs ``'img_shape'``; the first
                also needs ``'img_norm_cfg'``).
            result (list | tuple): a list with one bbox array per class, or
                a ``(bbox_result, segm_result)`` tuple when masks exist.
            dataset (None | str | list | tuple): source of class names:
                ``None`` uses ``self.CLASSES``, a str is resolved with
                ``get_classes``, a list/tuple is used as-is.
            score_thr (float): masks are drawn only for detections whose
                last bbox column exceeds this; it is also forwarded to
                ``mmcv.imshow_det_bboxes``.

        Raises:
            TypeError: if ``dataset`` has an unsupported type.
        """
        if isinstance(result, tuple):
            bbox_result, segm_result = result
        else:
            bbox_result, segm_result = result, None

        img_tensor = data['img'][0]
        img_metas = data['img_meta'][0].data[0]
        imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
        assert len(imgs) == len(img_metas)

        if dataset is None:
            class_names = self.CLASSES
        elif isinstance(dataset, str):
            class_names = get_classes(dataset)
        elif isinstance(dataset, (list, tuple)):
            class_names = dataset
        else:
            raise TypeError(
                'dataset must be a valid dataset name or a sequence'
                ' of class names, not {}'.format(type(dataset)))

        # The detections do not depend on the image being drawn, so stack
        # them once: one row per detection plus its class index.
        bboxes = np.vstack(bbox_result)
        labels = np.concatenate([
            np.full(bbox.shape[0], i, dtype=np.int32)
            for i, bbox in enumerate(bbox_result)
        ])
        if segm_result is not None:
            segms = mmcv.concat_list(segm_result)
            inds = np.where(bboxes[:, -1] > score_thr)[0]

        for img, img_meta in zip(imgs, img_metas):
            h, w, _ = img_meta['img_shape']
            img_show = img[:h, :w, :]
            # draw segmentation masks with a random color per detection
            if segm_result is not None:
                for i in inds:
                    color_mask = np.random.randint(
                        0, 256, (1, 3), dtype=np.uint8)
                    # builtin ``bool``: the ``np.bool`` alias was removed in
                    # NumPy 1.24.
                    mask = maskUtils.decode(segms[i]).astype(bool)
                    img_show[mask] = img_show[mask] * 0.5 + color_mask * 0.5
            # draw bounding boxes
            mmcv.imshow_det_bboxes(
                img_show,
                bboxes,
                labels,
                class_names=class_names,
                score_thr=score_thr)