base.py 6.77 KB
Newer Older
Kai Chen's avatar
Kai Chen committed
1
2
from abc import ABCMeta, abstractmethod

3
4
import mmcv
import numpy as np
Kai Chen's avatar
Kai Chen committed
5
import pycocotools.mask as maskUtils
6
import torch.nn as nn
Kai Chen's avatar
Kai Chen committed
7

8
from mmdet.core import auto_fp16, get_classes, tensor2imgs
9

Kai Chen's avatar
Kai Chen committed
10

11
class BaseDetector(nn.Module, metaclass=ABCMeta):
    """Base class for detectors"""

    def __init__(self):
        super().__init__()
        # Mixed precision stays off until an fp16 hook flips this flag.
        self.fp16_enabled = False
Kai Chen's avatar
Kai Chen committed
17

Kai Chen's avatar
Kai Chen committed
18
19
20
21
    @property
    def with_neck(self):
        return hasattr(self, 'neck') and self.neck is not None

myownskyW7's avatar
myownskyW7 committed
22
23
24
25
    @property
    def with_shared_head(self):
        return hasattr(self, 'shared_head') and self.shared_head is not None

Kai Chen's avatar
Kai Chen committed
26
27
28
29
30
31
32
33
    @property
    def with_bbox(self):
        return hasattr(self, 'bbox_head') and self.bbox_head is not None

    @property
    def with_mask(self):
        return hasattr(self, 'mask_head') and self.mask_head is not None

Kai Chen's avatar
Kai Chen committed
34
35
36
37
38
    @abstractmethod
    def extract_feat(self, imgs):
        """Extract backbone (and neck) features from an image batch.

        Must be implemented by concrete detectors.
        """
        pass

    def extract_feats(self, imgs):
Kai Chen's avatar
Kai Chen committed
39
40
41
        assert isinstance(imgs, list)
        for img in imgs:
            yield self.extract_feat(img)
Kai Chen's avatar
Kai Chen committed
42
43
44

    @abstractmethod
    def forward_train(self, imgs, img_metas, **kwargs):
        """
        Args:
            img (list[Tensor]): list of tensors of shape (1, C, H, W).
                Typically these should be mean centered and std scaled.

            img_metas (list[dict]): list of image info dict where each dict
                has:
                'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmdet/datasets/pipelines/formatting.py:Collect`.

             **kwargs: specific to concrete implementation
        """
        pass

61
    async def async_simple_test(self, img, img_meta, **kwargs):
        """Asynchronous single-augmentation test; override in subclasses."""
        raise NotImplementedError
63

Kai Chen's avatar
Kai Chen committed
64
65
66
67
68
69
70
71
    @abstractmethod
    def simple_test(self, img, img_meta, **kwargs):
        """Test without test-time augmentation (single image batch)."""
        pass

    @abstractmethod
    def aug_test(self, imgs, img_metas, **kwargs):
        """Test with test-time augmentations (multiple image batches)."""
        pass

72
73
    def init_weights(self, pretrained=None):
        if pretrained is not None:
74
75
            from mmdet.apis import get_root_logger
            logger = get_root_logger()
76
77
            logger.info('load model from: {}'.format(pretrained))

78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
    async def aforward_test(self, *, img, img_meta, **kwargs):
        for var, name in [(img, 'img'), (img_meta, 'img_meta')]:
            if not isinstance(var, list):
                raise TypeError('{} must be a list, but got {}'.format(
                    name, type(var)))

        num_augs = len(img)
        if num_augs != len(img_meta):
            raise ValueError(
                'num of augmentations ({}) != num of image meta ({})'.format(
                    len(img), len(img_meta)))
        # TODO: remove the restriction of imgs_per_gpu == 1 when prepared
        imgs_per_gpu = img[0].size(0)
        assert imgs_per_gpu == 1

        if num_augs == 1:
            return await self.async_simple_test(img[0], img_meta[0], **kwargs)
        else:
            raise NotImplementedError

Kai Chen's avatar
Kai Chen committed
98
    def forward_test(self, imgs, img_metas, **kwargs):
99
100
101
102
103
104
105
106
107
        """
        Args:
            imgs (List[Tensor]): the outer list indicates test-time
                augmentations and inner Tensor should have a shape NxCxHxW,
                which contains all images in the batch.
            img_meta (List[List[dict]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch
        """
Kai Chen's avatar
Kai Chen committed
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
        for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:
            if not isinstance(var, list):
                raise TypeError('{} must be a list, but got {}'.format(
                    name, type(var)))

        num_augs = len(imgs)
        if num_augs != len(img_metas):
            raise ValueError(
                'num of augmentations ({}) != num of image meta ({})'.format(
                    len(imgs), len(img_metas)))
        # TODO: remove the restriction of imgs_per_gpu == 1 when prepared
        imgs_per_gpu = imgs[0].size(0)
        assert imgs_per_gpu == 1

        if num_augs == 1:
            return self.simple_test(imgs[0], img_metas[0], **kwargs)
        else:
            return self.aug_test(imgs, img_metas, **kwargs)

Cao Yuhang's avatar
Cao Yuhang committed
127
    @auto_fp16(apply_to=('img', ))
    def forward(self, img, img_meta, return_loss=True, **kwargs):
        """
        Calls either forward_train or forward_test depending on whether
        return_loss=True. Note this setting will change the expected inputs.
        When `return_loss=True`, img and img_meta are single-nested (i.e.
        Tensor and List[dict]), and when `return_loss=False`, img and img_meta
        should be double nested (i.e.  List[Tensor], List[List[dict]]), with
        the outer list indicating test time augmentations.
        """
        handler = self.forward_train if return_loss else self.forward_test
        return handler(img, img_meta, **kwargs)
141

142
    def show_result(self, data, result, dataset=None, score_thr=0.3):
        """Draw detection (and optional segmentation) results on the images.

        Args:
            data (dict): input batch; must contain ``img`` and ``img_meta``
                as produced by the test pipeline (DataContainer-wrapped).
            result (tuple | list): either ``(bbox_result, segm_result)`` or
                ``bbox_result`` alone (per-class lists).
            dataset (str | list | tuple | None): class names, a dataset name
                resolvable by ``get_classes``, or None to use ``self.CLASSES``.
            score_thr (float): minimum score for a box/mask to be shown.

        Raises:
            TypeError: if ``dataset`` is of an unsupported type.
        """
        if isinstance(result, tuple):
            bbox_result, segm_result = result
        else:
            bbox_result, segm_result = result, None

        img_tensor = data['img'][0]
        img_metas = data['img_meta'][0].data[0]
        imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
        assert len(imgs) == len(img_metas)

        if dataset is None:
            class_names = self.CLASSES
        elif isinstance(dataset, str):
            class_names = get_classes(dataset)
        elif isinstance(dataset, (list, tuple)):
            class_names = dataset
        else:
            raise TypeError(
                'dataset must be a valid dataset name or a sequence'
                ' of class names, not {}'.format(type(dataset)))

        for img, img_meta in zip(imgs, img_metas):
            # crop away the padding added by the test pipeline
            h, w, _ = img_meta['img_shape']
            img_show = img[:h, :w, :]

            bboxes = np.vstack(bbox_result)
            # draw segmentation masks
            if segm_result is not None:
                segms = mmcv.concat_list(segm_result)
                inds = np.where(bboxes[:, -1] > score_thr)[0]
                for i in inds:
                    color_mask = np.random.randint(
                        0, 256, (1, 3), dtype=np.uint8)
                    # use builtin bool: the np.bool alias was deprecated in
                    # NumPy 1.20 and removed in 1.24
                    mask = maskUtils.decode(segms[i]).astype(bool)
                    img_show[mask] = img_show[mask] * 0.5 + color_mask * 0.5
            # draw bounding boxes
            labels = [
                np.full(bbox.shape[0], i, dtype=np.int32)
                for i, bbox in enumerate(bbox_result)
            ]
            labels = np.concatenate(labels)
            mmcv.imshow_det_bboxes(
                img_show,
                bboxes,
                labels,
                class_names=class_names,
                score_thr=score_thr)