base.py 6.87 KB
Newer Older
Kai Chen's avatar
Kai Chen committed
1
2
from abc import ABCMeta, abstractmethod

3
4
import mmcv
import numpy as np
Kai Chen's avatar
Kai Chen committed
5
import pycocotools.mask as maskUtils
6
import torch.nn as nn
Kai Chen's avatar
Kai Chen committed
7

8
from mmdet.core import auto_fp16, get_classes, tensor2imgs
Kai Chen's avatar
Kai Chen committed
9
from mmdet.utils import print_log
10

Kai Chen's avatar
Kai Chen committed
11

12
class BaseDetector(nn.Module, metaclass=ABCMeta):
Kai Chen's avatar
Kai Chen committed
13
14
15
16
    """Base class for detectors"""

    def __init__(self):
        """Set up nn.Module state and default runtime flags."""
        super(BaseDetector, self).__init__()
        # Mixed-precision (fp16) execution is off until explicitly enabled.
        self.fp16_enabled = False
Kai Chen's avatar
Kai Chen committed
18

Kai Chen's avatar
Kai Chen committed
19
20
21
22
    @property
    def with_neck(self):
        """bool: True when a non-None ``neck`` component is attached."""
        return getattr(self, 'neck', None) is not None

WXinlong's avatar
WXinlong committed
23
24
25
26
27
    @property
    def with_mask_feat_head(self):
        """bool: True when a non-None ``mask_feat_head`` is attached."""
        return getattr(self, 'mask_feat_head', None) is not None

myownskyW7's avatar
myownskyW7 committed
28
29
30
31
    @property
    def with_shared_head(self):
        """bool: True when a non-None ``shared_head`` is attached."""
        return getattr(self, 'shared_head', None) is not None

Kai Chen's avatar
Kai Chen committed
32
33
34
35
36
37
38
39
    @property
    def with_bbox(self):
        """bool: True when a non-None ``bbox_head`` is attached."""
        return getattr(self, 'bbox_head', None) is not None

    @property
    def with_mask(self):
        """bool: True when a non-None ``mask_head`` is attached."""
        return getattr(self, 'mask_head', None) is not None

Kai Chen's avatar
Kai Chen committed
40
41
42
43
44
    @abstractmethod
    def extract_feat(self, imgs):
        """Extract features from a batch of images.

        Must be implemented by concrete subclasses.
        """
        pass

    def extract_feats(self, imgs):
        """Lazily extract features for each image batch in ``imgs``.

        Args:
            imgs (list): one input per test-time augmentation; each item is
                fed to :meth:`extract_feat` in order.

        Yields:
            The feature output of :meth:`extract_feat` for each item.
        """
        assert isinstance(imgs, list)
        for single in imgs:
            yield self.extract_feat(single)
Kai Chen's avatar
Kai Chen committed
48
49
50

    @abstractmethod
    def forward_train(self, imgs, img_metas, **kwargs):
        """Run the training-time forward pass.

        Args:
            imgs (list[Tensor]): list of tensors of shape (1, C, H, W).
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): list of image info dicts where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also
                contain 'filename', 'ori_shape', 'pad_shape', and
                'img_norm_cfg'. For details on the values of these keys see
                `mmdet/datasets/pipelines/formatting.py:Collect`.
            **kwargs: specific to concrete implementation.
        """
        pass

67
    async def async_simple_test(self, img, img_meta, **kwargs):
        """Asynchronous single-augmentation test; not implemented here."""
        raise NotImplementedError
69

Kai Chen's avatar
Kai Chen committed
70
71
72
73
74
75
76
77
    @abstractmethod
    def simple_test(self, img, img_meta, **kwargs):
        """Test a single augmentation (see :meth:`forward_test`)."""
        pass

    @abstractmethod
    def aug_test(self, imgs, img_metas, **kwargs):
        """Test with multiple augmented inputs (see :meth:`forward_test`)."""
        pass

78
79
    def init_weights(self, pretrained=None):
        """Log the pretrained-weight source, if any.

        This base implementation only logs; subclasses are expected to do
        the actual weight initialization/loading.

        Args:
            pretrained (str, optional): identifier/path of the pretrained
                model. When None, nothing is logged.
        """
        if pretrained is None:
            return
        print_log('load model from: {}'.format(pretrained), logger='root')
81

82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
    async def aforward_test(self, *, img, img_meta, **kwargs):
        """Asynchronous test entry point.

        Args:
            img (list[Tensor]): one tensor per test-time augmentation.
            img_meta (list): one meta entry per test-time augmentation.

        Raises:
            TypeError: if ``img`` or ``img_meta`` is not a list.
            ValueError: if the two lists differ in length.
            NotImplementedError: when more than one augmentation is given.
        """
        for value, name in ((img, 'img'), (img_meta, 'img_meta')):
            if not isinstance(value, list):
                raise TypeError('{} must be a list, but got {}'.format(
                    name, type(value)))

        num_augs = len(img)
        if num_augs != len(img_meta):
            raise ValueError(
                'num of augmentations ({}) != num of image meta ({})'.format(
                    len(img), len(img_meta)))
        # TODO: remove the restriction of imgs_per_gpu == 1 when prepared
        imgs_per_gpu = img[0].size(0)
        assert imgs_per_gpu == 1

        if num_augs != 1:
            raise NotImplementedError
        return await self.async_simple_test(img[0], img_meta[0], **kwargs)

Kai Chen's avatar
Kai Chen committed
102
    def forward_test(self, imgs, img_metas, **kwargs):
        """Dispatch to :meth:`simple_test` or :meth:`aug_test`.

        Args:
            imgs (List[Tensor]): the outer list indicates test-time
                augmentations and inner Tensor should have a shape NxCxHxW,
                which contains all images in the batch.
            img_metas (List[List[dict]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch.

        Raises:
            TypeError: if ``imgs`` or ``img_metas`` is not a list.
            ValueError: if the two lists differ in length.
        """
        for value, name in ((imgs, 'imgs'), (img_metas, 'img_metas')):
            if not isinstance(value, list):
                raise TypeError('{} must be a list, but got {}'.format(
                    name, type(value)))

        num_augs = len(imgs)
        if num_augs != len(img_metas):
            raise ValueError(
                'num of augmentations ({}) != num of image meta ({})'.format(
                    len(imgs), len(img_metas)))
        # TODO: remove the restriction of imgs_per_gpu == 1 when prepared
        imgs_per_gpu = imgs[0].size(0)
        assert imgs_per_gpu == 1

        if num_augs == 1:
            return self.simple_test(imgs[0], img_metas[0], **kwargs)
        return self.aug_test(imgs, img_metas, **kwargs)

Cao Yuhang's avatar
Cao Yuhang committed
131
    @auto_fp16(apply_to=('img', ))
    def forward(self, img, img_meta, return_loss=True, **kwargs):
        """Route to :meth:`forward_train` or :meth:`forward_test`.

        Calls either forward_train or forward_test depending on
        ``return_loss``. Note this setting changes the expected inputs:
        when ``return_loss=True``, img and img_meta are single-nested
        (i.e. Tensor and List[dict]); when ``return_loss=False``, they
        should be double nested (i.e. List[Tensor], List[List[dict]]),
        with the outer list indicating test-time augmentations.
        """
        if return_loss:
            return self.forward_train(img, img_meta, **kwargs)
        return self.forward_test(img, img_meta, **kwargs)
145

146
    def show_result(self, data, result, dataset=None, score_thr=0.3):
        """Visualize detection (and optional mask) results on a data batch.

        Args:
            data (dict): input batch; must contain 'img' and 'img_meta'
                entries as produced by the data pipeline.
            result (tuple | list): either ``(bbox_result, segm_result)`` or
                a bare ``bbox_result`` (per-class list of Nx5 arrays).
            dataset (str | list[str] | tuple[str] | None): dataset name to
                look class names up from, an explicit sequence of class
                names, or None to use ``self.CLASSES``.
            score_thr (float): minimum score for a detection to be drawn.

        Raises:
            TypeError: if ``dataset`` is of an unsupported type.
        """
        if isinstance(result, tuple):
            bbox_result, segm_result = result
        else:
            bbox_result, segm_result = result, None

        img_tensor = data['img'][0]
        img_metas = data['img_meta'][0].data[0]
        # De-normalize tensors back to displayable images.
        imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
        assert len(imgs) == len(img_metas)

        if dataset is None:
            class_names = self.CLASSES
        elif isinstance(dataset, str):
            class_names = get_classes(dataset)
        elif isinstance(dataset, (list, tuple)):
            class_names = dataset
        else:
            raise TypeError(
                'dataset must be a valid dataset name or a sequence'
                ' of class names, not {}'.format(type(dataset)))

        for img, img_meta in zip(imgs, img_metas):
            h, w, _ = img_meta['img_shape']
            # Crop away the batch padding before drawing.
            img_show = img[:h, :w, :]

            bboxes = np.vstack(bbox_result)
            # draw segmentation masks
            if segm_result is not None:
                segms = mmcv.concat_list(segm_result)
                inds = np.where(bboxes[:, -1] > score_thr)[0]
                for i in inds:
                    color_mask = np.random.randint(
                        0, 256, (1, 3), dtype=np.uint8)
                    # use builtin `bool`: the `np.bool` alias was deprecated
                    # in NumPy 1.20 and removed in 1.24, so `astype(np.bool)`
                    # raises AttributeError on modern NumPy.
                    mask = maskUtils.decode(segms[i]).astype(bool)
                    img_show[mask] = img_show[mask] * 0.5 + color_mask * 0.5
            # draw bounding boxes
            labels = [
                np.full(bbox.shape[0], i, dtype=np.int32)
                for i, bbox in enumerate(bbox_result)
            ]
            labels = np.concatenate(labels)
            mmcv.imshow_det_bboxes(
                img_show,
                bboxes,
                labels,
                class_names=class_names,
                score_thr=score_thr)