import logging
from abc import ABCMeta, abstractmethod

import mmcv
import numpy as np
import pycocotools.mask as maskUtils
import torch.nn as nn

from mmdet.core import auto_fp16, get_classes, tensor2imgs


class BaseDetector(nn.Module, metaclass=ABCMeta):
    """Base class for detectors.

    Concrete detectors must implement :meth:`extract_feat`,
    :meth:`forward_train`, :meth:`simple_test` and :meth:`aug_test`.

    Note:
        ``__metaclass__ = ABCMeta`` (the original Python-2 spelling) is a
        no-op in Python 3, so the abstract methods were not enforced;
        ``metaclass=ABCMeta`` in the class header restores enforcement.
    """

    def __init__(self):
        super(BaseDetector, self).__init__()
        # Flipped to True by the fp16 training hooks; read by @auto_fp16.
        self.fp16_enabled = False

    @property
    def with_neck(self):
        """bool: whether the detector has a neck module."""
        return hasattr(self, 'neck') and self.neck is not None

    @property
    def with_shared_head(self):
        """bool: whether the detector has a shared head."""
        return hasattr(self, 'shared_head') and self.shared_head is not None

    @property
    def with_bbox(self):
        """bool: whether the detector has a bbox head."""
        return hasattr(self, 'bbox_head') and self.bbox_head is not None

    @property
    def with_mask(self):
        """bool: whether the detector has a mask head."""
        return hasattr(self, 'mask_head') and self.mask_head is not None

    @abstractmethod
    def extract_feat(self, imgs):
        """Extract features from a batch of images."""
        pass

    def extract_feats(self, imgs):
        """Lazily extract features for each test-time-augmented batch.

        Args:
            imgs (list[Tensor]): one image tensor per augmentation.

        Yields:
            the result of :meth:`extract_feat` for each tensor.
        """
        assert isinstance(imgs, list)
        for img in imgs:
            yield self.extract_feat(img)

    @abstractmethod
    def forward_train(self, imgs, img_metas, **kwargs):
        """
        Args:
            img (list[Tensor]): list of tensors of shape (1, C, H, W).
                Typically these should be mean centered and std scaled.

            img_metas (list[dict]): list of image info dict where each dict
                has:
                'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmdet/datasets/pipelines/formatting.py:Collect`.

             **kwargs: specific to concrete implementation
        """
        pass

    @abstractmethod
    def simple_test(self, img, img_meta, **kwargs):
        """Test without test-time augmentation."""
        pass

    @abstractmethod
    def aug_test(self, imgs, img_metas, **kwargs):
        """Test with test-time augmentations."""
        pass

    def init_weights(self, pretrained=None):
        """Log the pretrained checkpoint path; subclasses do the loading.

        Args:
            pretrained (str, optional): path to a pretrained checkpoint.
        """
        if pretrained is not None:
            logger = logging.getLogger()
            logger.info('load model from: {}'.format(pretrained))

    def forward_test(self, imgs, img_metas, **kwargs):
        """
        Args:
            imgs (List[Tensor]): the outer list indicates test-time
                augmentations and inner Tensor should have a shape NxCxHxW,
                which contains all images in the batch.
            img_meta (List[List[dict]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch
        """
        for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:
            if not isinstance(var, list):
                raise TypeError('{} must be a list, but got {}'.format(
                    name, type(var)))

        num_augs = len(imgs)
        if num_augs != len(img_metas):
            raise ValueError(
                'num of augmentations ({}) != num of image meta ({})'.format(
                    len(imgs), len(img_metas)))
        # TODO: remove the restriction of imgs_per_gpu == 1 when prepared
        imgs_per_gpu = imgs[0].size(0)
        assert imgs_per_gpu == 1

        if num_augs == 1:
            # Single augmentation: dispatch to the plain test path.
            return self.simple_test(imgs[0], img_metas[0], **kwargs)
        else:
            return self.aug_test(imgs, img_metas, **kwargs)

    @auto_fp16(apply_to=('img', ))
    def forward(self, img, img_meta, return_loss=True, **kwargs):
        """
        Calls either forward_train or forward_test depending on whether
        return_loss=True. Note this setting will change the expected inputs.
        When `return_loss=False`, img and img_meta are single-nested (i.e.
        Tensor and List[dict]), and when `return_loss=True`, img and img_meta
        should be double nested (i.e.  List[Tensor], List[List[dict]]), with
        the outer list indicating test time augmentations.
        """
        if return_loss:
            return self.forward_train(img, img_meta, **kwargs)
        else:
            return self.forward_test(img, img_meta, **kwargs)

    def show_result(self, data, result, dataset=None, score_thr=0.3):
        """Visualize detection (and optional segmentation) results.

        Args:
            data (dict): a data batch as produced by the test pipeline;
                must contain 'img' and 'img_meta'.
            result (tuple | list): either ``(bbox_result, segm_result)`` or
                just ``bbox_result`` (per-class lists of detections).
            dataset (str | Sequence[str], optional): dataset name or class
                names; defaults to ``self.CLASSES``.
            score_thr (float): minimum score for a detection to be drawn.

        Raises:
            TypeError: if ``dataset`` is neither None, a str, nor a sequence.
        """
        if isinstance(result, tuple):
            bbox_result, segm_result = result
        else:
            bbox_result, segm_result = result, None

        img_tensor = data['img'][0]
        img_metas = data['img_meta'][0].data[0]
        # Undo normalization so the images are displayable.
        imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
        assert len(imgs) == len(img_metas)

        if dataset is None:
            class_names = self.CLASSES
        elif isinstance(dataset, str):
            class_names = get_classes(dataset)
        elif isinstance(dataset, (list, tuple)):
            class_names = dataset
        else:
            raise TypeError(
                'dataset must be a valid dataset name or a sequence'
                ' of class names, not {}'.format(type(dataset)))

        for img, img_meta in zip(imgs, img_metas):
            # Crop away the zero padding added by the pipeline.
            h, w, _ = img_meta['img_shape']
            img_show = img[:h, :w, :]

            bboxes = np.vstack(bbox_result)
            # draw segmentation masks
            if segm_result is not None:
                segms = mmcv.concat_list(segm_result)
                inds = np.where(bboxes[:, -1] > score_thr)[0]
                for i in inds:
                    color_mask = np.random.randint(
                        0, 256, (1, 3), dtype=np.uint8)
                    # `np.bool` was removed in NumPy 1.24; builtin bool is
                    # the documented replacement.
                    mask = maskUtils.decode(segms[i]).astype(bool)
                    img_show[mask] = img_show[mask] * 0.5 + color_mask * 0.5
            # draw bounding boxes
            labels = [
                np.full(bbox.shape[0], i, dtype=np.int32)
                for i, bbox in enumerate(bbox_result)
            ]
            labels = np.concatenate(labels)
            mmcv.imshow_det_bboxes(
                img_show,
                bboxes,
                labels,
                class_names=class_names,
                score_thr=score_thr)