import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import _pair

from mmdet.core import (auto_fp16, bbox_target, delta2bbox, force_fp32,
                        multiclass_nms)
from ..builder import build_loss
from ..losses import accuracy
from ..registry import HEADS


@HEADS.register_module
class BBoxHead(nn.Module):
    """Simplest RoI head, with only two fc layers for classification and
    regression respectively.

    Args:
        with_avg_pool (bool): If True, RoI features are average-pooled with a
            kernel of ``roi_feat_size`` before being flattened; otherwise the
            fc layers consume ``in_channels * roi_feat_area`` inputs.
        with_cls (bool): Whether to build the classification fc layer.
        with_reg (bool): Whether to build the regression fc layer. At least
            one of ``with_cls`` and ``with_reg`` must be True.
        roi_feat_size (int | tuple[int]): Spatial size of the RoI features;
            expanded to a pair with ``_pair``.
        in_channels (int): Number of channels of the RoI features.
        num_classes (int): Output dimension of the classification layer.
            NOTE(review): ``loss`` treats ``labels > 0`` as positives, so
            class 0 is presumably background - confirm with the dataset.
        target_means (list[float]): Means forwarded to ``bbox_target`` and
            ``delta2bbox``.
        target_stds (list[float]): Stds forwarded to ``bbox_target`` and
            ``delta2bbox``.
        reg_class_agnostic (bool): If True the regression layer outputs 4
            values per RoI, otherwise ``4 * num_classes``.
        loss_cls (dict): Config passed to ``build_loss`` for classification.
        loss_bbox (dict): Config passed to ``build_loss`` for regression.
    """

    def __init__(self,
                 with_avg_pool=False,
                 with_cls=True,
                 with_reg=True,
                 roi_feat_size=7,
                 in_channels=256,
                 num_classes=81,
                 target_means=[0., 0., 0., 0.],
                 target_stds=[0.1, 0.1, 0.2, 0.2],
                 reg_class_agnostic=False,
                 loss_cls=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=False,
                     loss_weight=1.0),
                 loss_bbox=dict(
                     type='SmoothL1Loss', beta=1.0, loss_weight=1.0)):
        super(BBoxHead, self).__init__()
        assert with_cls or with_reg
        self.with_avg_pool = with_avg_pool
        self.with_cls = with_cls
        self.with_reg = with_reg
        self.roi_feat_size = _pair(roi_feat_size)
        self.roi_feat_area = self.roi_feat_size[0] * self.roi_feat_size[1]
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.target_means = target_means
        self.target_stds = target_stds
        self.reg_class_agnostic = reg_class_agnostic
        self.fp16_enabled = False

        self.loss_cls = build_loss(loss_cls)
        self.loss_bbox = build_loss(loss_bbox)

        # Input width of the fc layers: channels only when the spatial dims
        # are pooled away, channels * area when the feature map is flattened.
        in_channels = self.in_channels
        if self.with_avg_pool:
            self.avg_pool = nn.AvgPool2d(self.roi_feat_size)
        else:
            in_channels *= self.roi_feat_area
        if self.with_cls:
            self.fc_cls = nn.Linear(in_channels, num_classes)
        if self.with_reg:
            out_dim_reg = 4 if reg_class_agnostic else 4 * num_classes
            self.fc_reg = nn.Linear(in_channels, out_dim_reg)
        self.debug_imgs = None

    def init_weights(self):
        """Initialise the fc layers with zero-mean normal weights (std 0.01
        for classification, 0.001 for regression) and zero biases."""
        if self.with_cls:
            nn.init.normal_(self.fc_cls.weight, 0, 0.01)
            nn.init.constant_(self.fc_cls.bias, 0)
        if self.with_reg:
            nn.init.normal_(self.fc_reg.weight, 0, 0.001)
            nn.init.constant_(self.fc_reg.bias, 0)

    @auto_fp16()
    def forward(self, x):
        """Run the head on a batch of RoI features.

        Args:
            x (Tensor): RoI features; the first dimension indexes RoIs and
                the remaining dimensions are flattened before the fc layers.

        Returns:
            tuple: ``(cls_score, bbox_pred)``. Each entry is None when the
                corresponding branch (``with_cls`` / ``with_reg``) is off.
        """
        if self.with_avg_pool:
            x = self.avg_pool(x)
        x = x.view(x.size(0), -1)
        cls_score = self.fc_cls(x) if self.with_cls else None
        bbox_pred = self.fc_reg(x) if self.with_reg else None
        return cls_score, bbox_pred

    def get_target(self, sampling_results, gt_bboxes, gt_labels,
                   rcnn_train_cfg):
        """Build training targets from the sampling results.

        Args:
            sampling_results (list): Per-image sampling results exposing
                ``pos_bboxes``, ``neg_bboxes``, ``pos_gt_bboxes`` and
                ``pos_gt_labels``.
            gt_bboxes: Not used by this implementation.
            gt_labels: Not used by this implementation.
            rcnn_train_cfg: Training config forwarded to ``bbox_target``.

        Returns:
            The value returned by ``bbox_target``.
            NOTE(review): presumably the (labels, label_weights,
            bbox_targets, bbox_weights) tuple consumed by ``loss`` - confirm.
        """
        pos_proposals = [res.pos_bboxes for res in sampling_results]
        neg_proposals = [res.neg_bboxes for res in sampling_results]
        pos_gt_bboxes = [res.pos_gt_bboxes for res in sampling_results]
        pos_gt_labels = [res.pos_gt_labels for res in sampling_results]
        reg_classes = 1 if self.reg_class_agnostic else self.num_classes
        cls_reg_targets = bbox_target(
            pos_proposals,
            neg_proposals,
            pos_gt_bboxes,
            pos_gt_labels,
            rcnn_train_cfg,
            reg_classes,
            target_means=self.target_means,
            target_stds=self.target_stds)
        return cls_reg_targets

    @force_fp32(apply_to=('cls_score', 'bbox_pred'))
    def loss(self,
             cls_score,
             bbox_pred,
             labels,
             label_weights,
             bbox_targets,
             bbox_weights,
             reduction_override=None):
        """Compute classification and regression losses.

        Args:
            cls_score (Tensor | None): Classification scores; skipped if None.
            bbox_pred (Tensor | None): Regression output with 4 values per
                RoI (class agnostic) or 4 per class; skipped if None.
            labels (Tensor): Per-RoI labels; entries ``> 0`` are positives.
            label_weights (Tensor): Per-RoI classification weights.
            bbox_targets (Tensor): Per-RoI regression targets.
            bbox_weights (Tensor): Per-RoI regression weights.
            reduction_override: Forwarded unchanged to both loss modules.

        Returns:
            dict: May contain ``loss_cls`` and ``acc`` (when ``cls_score`` is
                given) and ``loss_bbox`` (when ``bbox_pred`` is given).
        """
        losses = dict()
        if cls_score is not None:
            # Normalise by the number of RoIs with non-zero weight, never
            # less than one to avoid dividing by zero.
            avg_factor = max(torch.sum(label_weights > 0).float().item(), 1.)
            losses['loss_cls'] = self.loss_cls(
                cls_score,
                labels,
                label_weights,
                avg_factor=avg_factor,
                reduction_override=reduction_override)
            losses['acc'] = accuracy(cls_score, labels)
        if bbox_pred is not None:
            pos_inds = labels > 0
            if self.reg_class_agnostic:
                pos_bbox_pred = bbox_pred.view(bbox_pred.size(0), 4)[pos_inds]
            else:
                # Pick, for every positive RoI, the 4 deltas that belong to
                # its own class.
                pos_bbox_pred = bbox_pred.view(bbox_pred.size(0), -1,
                                               4)[pos_inds, labels[pos_inds]]
            # Only positives contribute, but the loss is averaged over all
            # RoIs in ``bbox_targets``.
            losses['loss_bbox'] = self.loss_bbox(
                pos_bbox_pred,
                bbox_targets[pos_inds],
                bbox_weights[pos_inds],
                avg_factor=bbox_targets.size(0),
                reduction_override=reduction_override)
        return losses

    @force_fp32(apply_to=('cls_score', 'bbox_pred'))
    def get_det_bboxes(self,
                       rois,
                       cls_score,
                       bbox_pred,
                       img_shape,
                       scale_factor,
                       rescale=False,
                       cfg=None):
        """Turn head outputs into boxes and, optionally, final detections.

        Args:
            rois (Tensor): RoIs whose first column is skipped
                (``rois[:, 1:]`` are the box coordinates).
            cls_score (Tensor | list[Tensor] | None): Classification scores;
                a list is averaged element-wise before the softmax.
            bbox_pred (Tensor | None): Regression deltas. If None, the RoIs
                themselves are used as boxes.
            img_shape (tuple | None): ``(height, width, ...)`` used to keep
                boxes inside the image.
            scale_factor: Divisor applied to the boxes when ``rescale``.
            rescale (bool): Whether to divide boxes by ``scale_factor``.
            cfg: Test config with ``score_thr``, ``nms`` and ``max_per_img``.
                If None, NMS is skipped.

        Returns:
            tuple: ``(bboxes, scores)`` when ``cfg`` is None, otherwise the
                ``(det_bboxes, det_labels)`` pair from ``multiclass_nms``.
        """
        if isinstance(cls_score, list):
            cls_score = sum(cls_score) / float(len(cls_score))
        scores = F.softmax(cls_score, dim=1) if cls_score is not None else None

        if bbox_pred is not None:
            bboxes = delta2bbox(rois[:, 1:], bbox_pred, self.target_means,
                                self.target_stds, img_shape)
        else:
            bboxes = rois[:, 1:].clone()
            if img_shape is not None:
                # Strided slices are views, so the in-place clamp really
                # updates ``bboxes``. Indexing with a list ([0, 2]) would
                # return a copy and silently discard the clamp.
                bboxes[:, 0:3:2].clamp_(min=0, max=img_shape[1] - 1)
                bboxes[:, 1:4:2].clamp_(min=0, max=img_shape[0] - 1)

        if rescale:
            bboxes /= scale_factor

        if cfg is None:
            return bboxes, scores
        else:
            det_bboxes, det_labels = multiclass_nms(bboxes, scores,
                                                    cfg.score_thr, cfg.nms,
                                                    cfg.max_per_img)

            return det_bboxes, det_labels

    @force_fp32(apply_to=('bbox_preds', ))
    def refine_bboxes(self, rois, labels, bbox_preds, pos_is_gts, img_metas):
        """Refine bboxes during training.

        Args:
            rois (Tensor): Shape (n*bs, 5), where n is image number per GPU,
                and bs is the sampled RoIs per image.
            labels (Tensor): Shape (n*bs, ).
            bbox_preds (Tensor): Shape (n*bs, 4) or (n*bs, 4*#class).
            pos_is_gts (list[Tensor]): Flags indicating if each positive bbox
                is a gt bbox.
            img_metas (list[dict]): Meta info of each image.

        Returns:
            list[Tensor]: Refined bboxes of each image in a mini-batch.
        """
        img_ids = rois[:, 0].long().unique(sorted=True)
        assert img_ids.numel() == len(img_metas)

        bboxes_list = []
        for i, img_meta_ in enumerate(img_metas):
            # squeeze(1) keeps ``inds`` one-dimensional even when the image
            # has a single RoI; a bare squeeze() would yield a 0-dim tensor
            # and the indexing below would drop the RoI dimension.
            inds = torch.nonzero(rois[:, 0] == i).squeeze(1)
            num_rois = inds.numel()

            bboxes_ = rois[inds, 1:]
            label_ = labels[inds]
            bbox_pred_ = bbox_preds[inds]
            pos_is_gts_ = pos_is_gts[i]

            bboxes = self.regress_by_class(bboxes_, label_, bbox_pred_,
                                           img_meta_)
            # filter gt bboxes: the flags cover the leading (positive) RoIs,
            # every remaining RoI is kept. ``== 0`` inverts the flags for
            # both uint8 and bool tensors, unlike ``1 - flags``.
            pos_keep = pos_is_gts_ == 0
            keep_inds = pos_is_gts_.new_ones(num_rois)
            keep_inds[:len(pos_is_gts_)] = pos_keep

            bboxes_list.append(bboxes[keep_inds])

        return bboxes_list

    @force_fp32(apply_to=('bbox_pred', ))
    def regress_by_class(self, rois, label, bbox_pred, img_meta):
        """Regress the bbox for the predicted class. Used in Cascade R-CNN.

        Args:
            rois (Tensor): shape (n, 4) or (n, 5)
            label (Tensor): shape (n, )
            bbox_pred (Tensor): shape (n, 4*(#class+1)) or (n, 4)
            img_meta (dict): Image meta info; ``img_meta['img_shape']`` is
                forwarded to ``delta2bbox``.

        Returns:
            Tensor: Regressed bboxes, the same shape as input rois.
        """
        assert rois.size(1) == 4 or rois.size(1) == 5

        if not self.reg_class_agnostic:
            # Columns [4*label, 4*label + 3] hold the deltas of each RoI's
            # own class; gather them into an (n, 4) tensor.
            label = label * 4
            inds = torch.stack((label, label + 1, label + 2, label + 3), 1)
            bbox_pred = torch.gather(bbox_pred, 1, inds)
        assert bbox_pred.size(1) == 4

        if rois.size(1) == 4:
            new_rois = delta2bbox(rois, bbox_pred, self.target_means,
                                  self.target_stds, img_meta['img_shape'])
        else:
            # Keep the leading column untouched and regress only the box.
            bboxes = delta2bbox(rois[:, 1:], bbox_pred, self.target_means,
                                self.target_stds, img_meta['img_shape'])
            new_rois = torch.cat((rois[:, [0]], bboxes), dim=1)

        return new_rois