base.py 4.4 KB
Newer Older
1
import logging
Kai Chen's avatar
Kai Chen committed
2
3
from abc import ABCMeta, abstractmethod

4
5
import mmcv
import numpy as np
Kai Chen's avatar
Kai Chen committed
6
import pycocotools.mask as maskUtils
7
import torch.nn as nn
Kai Chen's avatar
Kai Chen committed
8

9
from mmdet.core import auto_fp16, get_classes, tensor2imgs
10

Kai Chen's avatar
Kai Chen committed
11
12
13
14
15
16
17
18

class BaseDetector(nn.Module):
    """Base class for detectors.

    Subclasses implement feature extraction (``extract_feat``), the training
    forward pass (``forward_train``) and single- / multi-augmentation testing
    (``simple_test`` / ``aug_test``). This base class dispatches between them
    in :meth:`forward` and :meth:`forward_test`, exposes ``with_*`` helpers
    for optional sub-modules, and provides :meth:`show_result` for drawing
    detections.
    """

    # NOTE: ``__metaclass__`` is only honoured by Python 2. On Python 3 this
    # assignment does not make ABCMeta the metaclass, so the
    # ``@abstractmethod`` markers below are NOT enforced at instantiation.
    # ``class BaseDetector(nn.Module, metaclass=ABCMeta)`` would enforce them,
    # but could break subclasses that leave some of them unimplemented, so
    # the original declaration is kept unchanged.
    __metaclass__ = ABCMeta

    def __init__(self):
        super(BaseDetector, self).__init__()
        # Mixed-precision switch; presumably consulted by the ``auto_fp16``
        # decorator on ``forward`` and toggled elsewhere — TODO confirm.
        self.fp16_enabled = False

    @property
    def with_neck(self):
        """bool: True if the detector has a ``neck`` attribute that is not None."""
        return hasattr(self, 'neck') and self.neck is not None

    @property
    def with_shared_head(self):
        """bool: True if the detector has a non-None ``shared_head`` attribute."""
        return hasattr(self, 'shared_head') and self.shared_head is not None

    @property
    def with_bbox(self):
        """bool: True if the detector has a non-None ``bbox_head`` attribute."""
        return hasattr(self, 'bbox_head') and self.bbox_head is not None

    @property
    def with_mask(self):
        """bool: True if the detector has a non-None ``mask_head`` attribute."""
        return hasattr(self, 'mask_head') and self.mask_head is not None

    @abstractmethod
    def extract_feat(self, imgs):
        """Extract features from one image batch (implemented by subclasses)."""
        pass

    def extract_feats(self, imgs):
        """Lazily extract features for every element of ``imgs``.

        Args:
            imgs (list): One input per augmentation; each element is passed
                unchanged to :meth:`extract_feat`.

        Returns:
            generator: Yields ``self.extract_feat(img)`` for each ``img`` in
            order, computing each result only when it is requested.

        Raises:
            AssertionError: If ``imgs`` is not a list. The check runs eagerly
                at call time, before any features are extracted.
        """
        assert isinstance(imgs, list)
        return (self.extract_feat(img) for img in imgs)

    @abstractmethod
    def forward_train(self, imgs, img_metas, **kwargs):
        """Training forward pass (implemented by subclasses)."""
        pass

    @abstractmethod
    def simple_test(self, img, img_meta, **kwargs):
        """Test without test-time augmentation (implemented by subclasses)."""
        pass

    @abstractmethod
    def aug_test(self, imgs, img_metas, **kwargs):
        """Test with test-time augmentation (implemented by subclasses)."""
        pass

    def init_weights(self, pretrained=None):
        """Log the source of pretrained weights, if one is given.

        This base implementation only logs; it loads nothing itself.

        Args:
            pretrained (str, optional): Identifier/path of pretrained weights.
        """
        if pretrained is not None:
            logger = logging.getLogger()
            logger.info('load model from: {}'.format(pretrained))

    def forward_test(self, imgs, img_metas, **kwargs):
        """Validate test inputs and dispatch to single or augmented testing.

        Args:
            imgs (list): One image batch per test-time augmentation. Each
                element must support ``.size(0)`` (presumably a
                ``torch.Tensor`` — confirm) and currently hold one image.
            img_metas (list): One meta entry per augmentation, parallel to
                ``imgs``.
            **kwargs: Forwarded to ``simple_test`` / ``aug_test``.

        Returns:
            The result of ``simple_test`` when there is exactly one
            augmentation, otherwise the result of ``aug_test``.

        Raises:
            TypeError: If ``imgs`` or ``img_metas`` is not a list.
            ValueError: If ``imgs`` and ``img_metas`` differ in length.
        """
        for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:
            if not isinstance(var, list):
                raise TypeError('{} must be a list, but got {}'.format(
                    name, type(var)))

        num_augs = len(imgs)
        if num_augs != len(img_metas):
            raise ValueError(
                'num of augmentations ({}) != num of image meta ({})'.format(
                    len(imgs), len(img_metas)))
        # TODO: remove the restriction of imgs_per_gpu == 1 when prepared
        imgs_per_gpu = imgs[0].size(0)
        assert imgs_per_gpu == 1

        if num_augs == 1:
            return self.simple_test(imgs[0], img_metas[0], **kwargs)
        else:
            return self.aug_test(imgs, img_metas, **kwargs)

    @auto_fp16(apply_to=('img', ))
    def forward(self, img, img_meta, return_loss=True, **kwargs):
        """Dispatch to training or testing depending on ``return_loss``.

        With ``return_loss=True`` the arguments go to :meth:`forward_train`;
        otherwise they go to :meth:`forward_test`, which expects
        per-augmentation lists. The ``auto_fp16`` decorator targets the
        ``img`` argument (presumably casting it for mixed precision when
        ``fp16_enabled`` is set — confirm).
        """
        if return_loss:
            return self.forward_train(img, img_meta, **kwargs)
        else:
            return self.forward_test(img, img_meta, **kwargs)

    def show_result(self, data, result, dataset=None, score_thr=0.3):
        """Draw detection (and optional mask) results on the input images.

        Drawing and display are delegated to ``mmcv.imshow_det_bboxes``.

        Args:
            data (dict): Test-pipeline output. ``data['img'][0]`` is the image
                tensor batch and ``data['img_meta'][0].data[0]`` the matching
                list of meta dicts (presumably an mmcv ``DataContainer`` —
                confirm). Each meta dict must contain ``img_shape``, and the
                first one must contain ``img_norm_cfg``, which is used to
                de-normalise the tensor.
            result (list | tuple): Per-class bbox arrays whose last column is
                the score, or a ``(bbox_result, segm_result)`` tuple when
                masks are present.
            dataset (None | str | list | tuple): Source of class names.
                ``None`` uses ``self.CLASSES``, a string is resolved with
                ``get_classes``, and a list/tuple is used as-is.
            score_thr (float): Minimum score for a mask/box to be drawn.

        Raises:
            TypeError: If ``dataset`` is not one of the accepted types.
        """
        if isinstance(result, tuple):
            bbox_result, segm_result = result
        else:
            bbox_result, segm_result = result, None

        img_tensor = data['img'][0]
        img_metas = data['img_meta'][0].data[0]
        imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
        assert len(imgs) == len(img_metas)

        if dataset is None:
            class_names = self.CLASSES
        elif isinstance(dataset, str):
            class_names = get_classes(dataset)
        elif isinstance(dataset, (list, tuple)):
            class_names = dataset
        else:
            raise TypeError(
                'dataset must be a valid dataset name or a sequence'
                ' of class names, not {}'.format(type(dataset)))

        # The detections are the same for every image drawn below, so stack
        # the boxes, build their per-class labels and select the masks to
        # draw once instead of once per image.
        bboxes = np.vstack(bbox_result)
        labels = np.concatenate([
            np.full(bbox.shape[0], i, dtype=np.int32)
            for i, bbox in enumerate(bbox_result)
        ])
        if segm_result is not None:
            segms = mmcv.concat_list(segm_result)
            inds = np.where(bboxes[:, -1] > score_thr)[0]

        for img, img_meta in zip(imgs, img_metas):
            h, w, _ = img_meta['img_shape']
            # Crop to ``img_shape`` (presumably the unpadded size — confirm).
            img_show = img[:h, :w, :]

            # draw segmentation masks
            if segm_result is not None:
                for i in inds:
                    color_mask = np.random.randint(
                        0, 256, (1, 3), dtype=np.uint8)
                    # ``np.bool`` was removed in NumPy 1.24; it was an alias
                    # of the builtin ``bool``, so this is the same dtype.
                    mask = maskUtils.decode(segms[i]).astype(bool)
                    img_show[mask] = img_show[mask] * 0.5 + color_mask * 0.5
            # draw bounding boxes
            mmcv.imshow_det_bboxes(
                img_show,
                bboxes,
                labels,
                class_names=class_names,
                score_thr=score_thr)