base.py 6.73 KB
Newer Older
Kai Chen's avatar
Kai Chen committed
1
2
from abc import ABCMeta, abstractmethod

3
4
import mmcv
import numpy as np
Kai Chen's avatar
Kai Chen committed
5
import pycocotools.mask as maskUtils
6
import torch.nn as nn
Kai Chen's avatar
Kai Chen committed
7

8
from mmdet.core import auto_fp16, get_classes, tensor2imgs
Kai Chen's avatar
Kai Chen committed
9
from mmdet.utils import print_log
10

Kai Chen's avatar
Kai Chen committed
11

12
class BaseDetector(nn.Module, metaclass=ABCMeta):
    """Base class for detectors.

    Concrete detectors implement :meth:`extract_feat`,
    :meth:`forward_train`, :meth:`simple_test` and :meth:`aug_test`; this
    base class routes :meth:`forward` between training and testing and
    provides result visualization via :meth:`show_result`.
    """

    def __init__(self):
        super(BaseDetector, self).__init__()
        # Flipped to True by the fp16 training hooks; read by @auto_fp16.
        self.fp16_enabled = False

    @property
    def with_neck(self):
        """bool: True if the detector defines a non-None ``neck``."""
        return hasattr(self, 'neck') and self.neck is not None

    @property
    def with_shared_head(self):
        """bool: True if the detector defines a non-None ``shared_head``."""
        return hasattr(self, 'shared_head') and self.shared_head is not None

    @property
    def with_bbox(self):
        """bool: True if the detector defines a non-None ``bbox_head``."""
        return hasattr(self, 'bbox_head') and self.bbox_head is not None

    @property
    def with_mask(self):
        """bool: True if the detector defines a non-None ``mask_head``."""
        return hasattr(self, 'mask_head') and self.mask_head is not None

    @abstractmethod
    def extract_feat(self, imgs):
        """Extract features from a batch of images."""
        pass

    def extract_feats(self, imgs):
        """Yield features for each augmented image batch in ``imgs``.

        Args:
            imgs (list[Tensor]): one entry per test-time augmentation.

        Yields:
            the result of :meth:`extract_feat` for each entry.
        """
        assert isinstance(imgs, list)
        for img in imgs:
            yield self.extract_feat(img)

    @abstractmethod
    def forward_train(self, imgs, img_metas, **kwargs):
        """
        Args:
            img (list[Tensor]): list of tensors of shape (1, C, H, W).
                Typically these should be mean centered and std scaled.

            img_metas (list[dict]): list of image info dict where each dict
                has:
                'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmdet/datasets/pipelines/formatting.py:Collect`.

             **kwargs: specific to concrete implementation
        """
        pass

    async def async_simple_test(self, img, img_meta, **kwargs):
        # Optional async test path; subclasses may override.
        raise NotImplementedError

    @abstractmethod
    def simple_test(self, img, img_meta, **kwargs):
        pass

    @abstractmethod
    def aug_test(self, imgs, img_metas, **kwargs):
        pass

    def init_weights(self, pretrained=None):
        """Log the checkpoint path; subclasses perform the actual loading.

        Args:
            pretrained (str, optional): path or URL of the pretrained model.
        """
        if pretrained is not None:
            print_log('load model from: {}'.format(pretrained), logger='root')

    async def aforward_test(self, *, img, img_meta, **kwargs):
        """Async counterpart of :meth:`forward_test`.

        Only a single augmentation (``len(img) == 1``) is supported.
        """
        for var, name in [(img, 'img'), (img_meta, 'img_meta')]:
            if not isinstance(var, list):
                raise TypeError('{} must be a list, but got {}'.format(
                    name, type(var)))

        num_augs = len(img)
        if num_augs != len(img_meta):
            raise ValueError(
                'num of augmentations ({}) != num of image meta ({})'.format(
                    len(img), len(img_meta)))
        # TODO: remove the restriction of imgs_per_gpu == 1 when prepared
        imgs_per_gpu = img[0].size(0)
        assert imgs_per_gpu == 1

        if num_augs == 1:
            return await self.async_simple_test(img[0], img_meta[0], **kwargs)
        else:
            raise NotImplementedError

    def forward_test(self, imgs, img_metas, **kwargs):
        """
        Args:
            imgs (List[Tensor]): the outer list indicates test-time
                augmentations and inner Tensor should have a shape NxCxHxW,
                which contains all images in the batch.
            img_meta (List[List[dict]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch
        """
        for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:
            if not isinstance(var, list):
                raise TypeError('{} must be a list, but got {}'.format(
                    name, type(var)))

        num_augs = len(imgs)
        if num_augs != len(img_metas):
            raise ValueError(
                'num of augmentations ({}) != num of image meta ({})'.format(
                    len(imgs), len(img_metas)))
        # TODO: remove the restriction of imgs_per_gpu == 1 when prepared
        imgs_per_gpu = imgs[0].size(0)
        assert imgs_per_gpu == 1

        if num_augs == 1:
            # Single augmentation: plain single-scale testing.
            return self.simple_test(imgs[0], img_metas[0], **kwargs)
        else:
            return self.aug_test(imgs, img_metas, **kwargs)

    @auto_fp16(apply_to=('img', ))
    def forward(self, img, img_meta, return_loss=True, **kwargs):
        """
        Calls either forward_train or forward_test depending on whether
        return_loss=True. Note this setting will change the expected inputs.
        When `return_loss=True`, img and img_meta are single-nested (i.e.
        Tensor and List[dict]), and when `return_loss=False`, img and img_meta
        should be double nested (i.e.  List[Tensor], List[List[dict]]), with
        the outer list indicating test time augmentations.
        """
        if return_loss:
            return self.forward_train(img, img_meta, **kwargs)
        else:
            return self.forward_test(img, img_meta, **kwargs)

    def show_result(self, data, result, dataset=None, score_thr=0.3):
        """Draw detection (and optional segmentation) results on the images.

        Args:
            data (dict): batch data containing 'img' and 'img_meta'.
            result (tuple[list] | list): ``(bbox_result, segm_result)`` for
                detectors with a mask head, otherwise just ``bbox_result``.
            dataset (str | list[str] | tuple | None): class names, or a
                dataset name understood by ``get_classes``; ``None`` falls
                back to ``self.CLASSES``.
            score_thr (float): minimum score for a box/mask to be drawn.

        Raises:
            TypeError: if ``dataset`` is not None, a str, a list or a tuple.
        """
        if isinstance(result, tuple):
            bbox_result, segm_result = result
        else:
            bbox_result, segm_result = result, None

        img_tensor = data['img'][0]
        img_metas = data['img_meta'][0].data[0]
        # Undo normalization so the images can be displayed.
        imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
        assert len(imgs) == len(img_metas)

        if dataset is None:
            class_names = self.CLASSES
        elif isinstance(dataset, str):
            class_names = get_classes(dataset)
        elif isinstance(dataset, (list, tuple)):
            class_names = dataset
        else:
            raise TypeError(
                'dataset must be a valid dataset name or a sequence'
                ' of class names, not {}'.format(type(dataset)))

        for img, img_meta in zip(imgs, img_metas):
            # Crop off the padded region before drawing.
            h, w, _ = img_meta['img_shape']
            img_show = img[:h, :w, :]

            bboxes = np.vstack(bbox_result)
            # draw segmentation masks
            if segm_result is not None:
                segms = mmcv.concat_list(segm_result)
                inds = np.where(bboxes[:, -1] > score_thr)[0]
                for i in inds:
                    color_mask = np.random.randint(
                        0, 256, (1, 3), dtype=np.uint8)
                    # np.bool was deprecated in NumPy 1.20 and removed in
                    # 1.24; the builtin bool is the correct dtype alias.
                    mask = maskUtils.decode(segms[i]).astype(bool)
                    img_show[mask] = img_show[mask] * 0.5 + color_mask * 0.5
            # draw bounding boxes
            labels = [
                np.full(bbox.shape[0], i, dtype=np.int32)
                for i, bbox in enumerate(bbox_result)
            ]
            labels = np.concatenate(labels)
            mmcv.imshow_det_bboxes(
                img_show,
                bboxes,
                labels,
                class_names=class_names,
                score_thr=score_thr)