base.py 5 KB
Newer Older
1
import logging
Kai Chen's avatar
Kai Chen committed
2
3
from abc import ABCMeta, abstractmethod

4
5
import mmcv
import numpy as np
Kai Chen's avatar
Kai Chen committed
6
import pycocotools.mask as maskUtils
7
import torch.nn as nn
Kai Chen's avatar
Kai Chen committed
8

9
from mmdet.core import auto_fp16, get_classes, tensor2imgs
10

Kai Chen's avatar
Kai Chen committed
11
12
13
14
15
16
17
18

class BaseDetector(nn.Module):
    """Base class for detectors.

    Subclasses provide feature extraction (``extract_feat``), the training
    pass (``forward_train``) and two inference paths: ``simple_test`` for a
    single augmentation and ``aug_test`` for several. ``forward`` dispatches
    between training and testing, and ``show_result`` draws detection /
    segmentation outputs on the input image.
    """

    # NOTE(review): ``__metaclass__`` is only honoured by Python 2. Under
    # Python 3 this is a plain class attribute, so the ``@abstractmethod``
    # decorators below are NOT enforced at instantiation time. Declaring
    # ``class BaseDetector(nn.Module, metaclass=ABCMeta)`` would enforce
    # them, but would also make any subclass that leaves an abstract method
    # unimplemented fail to instantiate. It is therefore left unchanged.
    __metaclass__ = ABCMeta

    def __init__(self):
        super(BaseDetector, self).__init__()
        # Starts disabled. Presumably consulted by the ``auto_fp16`` decorator
        # applied to ``forward`` (defined in mmdet.core) — confirm.
        self.fp16_enabled = False

    @property
    def with_neck(self):
        """bool: True if the detector has a non-None ``neck`` attribute."""
        return hasattr(self, 'neck') and self.neck is not None

    @property
    def with_shared_head(self):
        """bool: True if the detector has a non-None ``shared_head``."""
        return hasattr(self, 'shared_head') and self.shared_head is not None

    @property
    def with_bbox(self):
        """bool: True if the detector has a non-None ``bbox_head``."""
        return hasattr(self, 'bbox_head') and self.bbox_head is not None

    @property
    def with_mask(self):
        """bool: True if the detector has a non-None ``mask_head``."""
        return hasattr(self, 'mask_head') and self.mask_head is not None

    @abstractmethod
    def extract_feat(self, imgs):
        """Extract features from one input. Implemented by subclasses."""
        pass

    def extract_feats(self, imgs):
        """Lazily extract features for each element of ``imgs``.

        Args:
            imgs (list): Inputs passed one at a time to ``extract_feat``.

        Returns:
            generator: Yields ``self.extract_feat(img)`` for each element of
            ``imgs``, in order.
        """
        # Validate eagerly. Inside a generator function this assert would
        # only run once the caller started iterating.
        assert isinstance(imgs, list)
        return (self.extract_feat(img) for img in imgs)

    @abstractmethod
    def forward_train(self, imgs, img_metas, **kwargs):
        """
        Args:
            imgs (list[Tensor]): list of tensors of shape (1, C, H, W).
                Typically these should be mean centered and std scaled.

            img_metas (list[dict]): list of image info dict where each dict
                has:
                'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmdet/datasets/pipelines/formatting.py:Collect`.

            **kwargs: specific to concrete implementation
        """
        pass

    @abstractmethod
    def simple_test(self, img, img_meta, **kwargs):
        """Inference with a single augmentation.

        ``forward_test`` calls this with ``imgs[0]`` and ``img_metas[0]``.
        """
        pass

    @abstractmethod
    def aug_test(self, imgs, img_metas, **kwargs):
        """Inference with several augmentations.

        ``forward_test`` calls this with the full ``imgs`` / ``img_metas``
        lists.
        """
        pass

    def init_weights(self, pretrained=None):
        """Log the pretrained checkpoint path, if one is given.

        This base implementation only logs through the root logger. It does
        not load any weights itself.

        Args:
            pretrained (str | None): Checkpoint identifier to report.
        """
        if pretrained is not None:
            logger = logging.getLogger()
            logger.info('load model from: {}'.format(pretrained))

    def forward_test(self, imgs, img_metas, **kwargs):
        """Run inference, dispatching on the number of augmentations.

        Args:
            imgs (list[Tensor]): One image batch per test-time augmentation.
            img_metas (list): One meta entry per augmentation, aligned with
                ``imgs``.

        Returns:
            The result of ``simple_test`` when exactly one augmentation is
            given, otherwise the result of ``aug_test``.

        Raises:
            TypeError: If ``imgs`` or ``img_metas`` is not a list.
            ValueError: If ``imgs`` and ``img_metas`` differ in length.
            AssertionError: If the first batch does not hold exactly one
                image.
        """
        for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:
            if not isinstance(var, list):
                raise TypeError('{} must be a list, but got {}'.format(
                    name, type(var)))

        num_augs = len(imgs)
        if num_augs != len(img_metas):
            raise ValueError(
                'num of augmentations ({}) != num of image meta ({})'.format(
                    len(imgs), len(img_metas)))
        # TODO: remove the restriction of imgs_per_gpu == 1 when prepared
        imgs_per_gpu = imgs[0].size(0)
        assert imgs_per_gpu == 1

        if num_augs == 1:
            return self.simple_test(imgs[0], img_metas[0], **kwargs)
        else:
            return self.aug_test(imgs, img_metas, **kwargs)

    @auto_fp16(apply_to=('img', ))
    def forward(self, img, img_meta, return_loss=True, **kwargs):
        """Dispatch to ``forward_train`` or ``forward_test``.

        Args:
            img: Image input, forwarded unchanged to the chosen method.
            img_meta: Image meta info, forwarded unchanged.
            return_loss (bool): If True, call ``forward_train``; otherwise
                call ``forward_test``.
            **kwargs: Forwarded to the chosen method.
        """
        if return_loss:
            return self.forward_train(img, img_meta, **kwargs)
        else:
            return self.forward_test(img, img_meta, **kwargs)

    def show_result(self, data, result, dataset=None, score_thr=0.3):
        """Draw detection (and optional mask) results on the input image(s).

        Args:
            data (dict): Test input batch. ``data['img'][0]`` is used as the
                image tensor and ``data['img_meta'][0].data[0]`` as the list
                of per-image meta dicts (presumably an mmcv DataContainer —
                confirm). Each meta dict must provide ``'img_norm_cfg'`` and
                ``'img_shape'``.
            result (list[ndarray] | tuple): Either per-class bbox arrays or a
                ``(bbox_result, segm_result)`` tuple. After stacking, the last
                bbox column is compared against ``score_thr``.
            dataset (None | str | list | tuple): Source of class names.
                ``None`` uses ``self.CLASSES``, a str is resolved with
                ``get_classes``, and a list/tuple is used as-is.
            score_thr (float): Only boxes/masks scoring above this are drawn.

        Raises:
            TypeError: If ``dataset`` has an unsupported type.
        """
        if isinstance(result, tuple):
            bbox_result, segm_result = result
        else:
            bbox_result, segm_result = result, None

        img_tensor = data['img'][0]
        img_metas = data['img_meta'][0].data[0]
        imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
        assert len(imgs) == len(img_metas)

        if dataset is None:
            # NOTE(review): ``CLASSES`` is not set in this class; it is
            # presumably attached to the model by the calling script.
            class_names = self.CLASSES
        elif isinstance(dataset, str):
            class_names = get_classes(dataset)
        elif isinstance(dataset, (list, tuple)):
            class_names = dataset
        else:
            raise TypeError(
                'dataset must be a valid dataset name or a sequence'
                ' of class names, not {}'.format(type(dataset)))

        for img, img_meta in zip(imgs, img_metas):
            # Crop away padding so only the valid image area is drawn.
            h, w, _ = img_meta['img_shape']
            img_show = img[:h, :w, :]

            # Both the boxes and the masks are concatenated in class order,
            # so row i of ``bboxes`` corresponds to ``segms[i]``.
            bboxes = np.vstack(bbox_result)
            # draw segmentation masks
            if segm_result is not None:
                segms = mmcv.concat_list(segm_result)
                inds = np.where(bboxes[:, -1] > score_thr)[0]
                for i in inds:
                    # Random color per instance, so colors differ between
                    # calls.
                    color_mask = np.random.randint(
                        0, 256, (1, 3), dtype=np.uint8)
                    # Builtin ``bool``: the ``np.bool`` alias was removed in
                    # NumPy 1.24.
                    mask = maskUtils.decode(segms[i]).astype(bool)
                    img_show[mask] = img_show[mask] * 0.5 + color_mask * 0.5
            # draw bounding boxes
            # Every box from the i-th per-class array gets label i.
            labels = [
                np.full(bbox.shape[0], i, dtype=np.int32)
                for i, bbox in enumerate(bbox_result)
            ]
            labels = np.concatenate(labels)
            mmcv.imshow_det_bboxes(
                img_show,
                bboxes,
                labels,
                class_names=class_names,
                score_thr=score_thr)