# bbox_head.py
import torch.nn as nn
import torch.nn.functional as F

from mmdet.core import (delta2bbox, multiclass_nms, bbox_target,
                        weighted_cross_entropy, weighted_smoothl1, accuracy)


class BBoxHead(nn.Module):
    """Simplest RoI head, with only two fc layers for classification and
    regression respectively.

    Args:
        with_avg_pool (bool): If True, RoI features are average-pooled with
            a kernel of ``roi_feat_size`` before the fc layers; otherwise
            the fc layers consume the flattened
            ``in_channels * roi_feat_size ** 2`` features.
        with_cls (bool): Build the classification branch (``fc_cls``).
        with_reg (bool): Build the regression branch (``fc_reg``).
        roi_feat_size (int): Spatial size of the incoming RoI features.
        in_channels (int): Number of channels of the incoming RoI features.
        num_classes (int): Output size of ``fc_cls``.
        target_means (sequence of float): Forwarded to ``bbox_target`` and
            ``delta2bbox``.
        target_stds (sequence of float): Forwarded to ``bbox_target`` and
            ``delta2bbox``.
        reg_class_agnostic (bool): If True, ``fc_reg`` outputs 4 values per
            RoI instead of ``4 * num_classes``.
    """

    def __init__(self,
                 with_avg_pool=False,
                 with_cls=True,
                 with_reg=True,
                 roi_feat_size=7,
                 in_channels=256,
                 num_classes=81,
                 target_means=[0., 0., 0., 0.],
                 target_stds=[0.1, 0.1, 0.2, 0.2],
                 reg_class_agnostic=False):
        super(BBoxHead, self).__init__()
        assert with_cls or with_reg
        self.with_avg_pool = with_avg_pool
        self.with_cls = with_cls
        self.with_reg = with_reg
        self.roi_feat_size = roi_feat_size
        self.in_channels = in_channels
        self.num_classes = num_classes
        # Store copies: the signature defaults are mutable lists created once
        # at definition time, so keeping them by reference would make every
        # default-constructed head share (and potentially corrupt) one object.
        self.target_means = list(target_means)
        self.target_stds = list(target_stds)
        self.reg_class_agnostic = reg_class_agnostic

        # Width of the flattened feature vector fed to the fc layers.
        in_channels = self.in_channels
        if self.with_avg_pool:
            self.avg_pool = nn.AvgPool2d(roi_feat_size)
        else:
            in_channels *= (self.roi_feat_size * self.roi_feat_size)
        if self.with_cls:
            self.fc_cls = nn.Linear(in_channels, num_classes)
        if self.with_reg:
            out_dim_reg = 4 if reg_class_agnostic else 4 * num_classes
            self.fc_reg = nn.Linear(in_channels, out_dim_reg)
        self.debug_imgs = None

    def init_weights(self):
        """Initialise the fc layers: normally distributed weights (std 0.01
        for classification, 0.001 for regression) and zero biases."""
        if self.with_cls:
            nn.init.normal_(self.fc_cls.weight, 0, 0.01)
            nn.init.constant_(self.fc_cls.bias, 0)
        if self.with_reg:
            nn.init.normal_(self.fc_reg.weight, 0, 0.001)
            nn.init.constant_(self.fc_reg.bias, 0)

    def forward(self, x):
        """Run the head on RoI features.

        Args:
            x (Tensor): RoI features; the first dimension is kept and all
                remaining dimensions are flattened before the fc layers.

        Returns:
            tuple: ``(cls_score, bbox_pred)``; an entry is None when the
            corresponding branch is disabled.
        """
        if self.with_avg_pool:
            x = self.avg_pool(x)
        x = x.view(x.size(0), -1)
        cls_score = self.fc_cls(x) if self.with_cls else None
        bbox_pred = self.fc_reg(x) if self.with_reg else None
        return cls_score, bbox_pred

    def get_bbox_target(self, pos_proposals, neg_proposals, pos_gt_bboxes,
                        pos_gt_labels, rcnn_train_cfg):
        """Build training targets by delegating to ``bbox_target``.

        The arguments are forwarded unchanged, together with the number of
        regression classes (1 when class-agnostic, otherwise
        ``num_classes``) and this head's ``target_means`` / ``target_stds``.

        Returns:
            Whatever ``bbox_target`` returns.
        """
        reg_num_classes = 1 if self.reg_class_agnostic else self.num_classes
        cls_reg_targets = bbox_target(
            pos_proposals,
            neg_proposals,
            pos_gt_bboxes,
            pos_gt_labels,
            rcnn_train_cfg,
            reg_num_classes,
            target_means=self.target_means,
            target_stds=self.target_stds)
        return cls_reg_targets

    def loss(self, cls_score, bbox_pred, labels, label_weights, bbox_targets,
             bbox_weights):
        """Compute the losses of the enabled branches.

        Returns:
            dict: Contains ``'loss_cls'`` and ``'acc'`` when ``cls_score``
            is not None, and ``'loss_reg'`` when ``bbox_pred`` is not None.
        """
        losses = dict()
        if cls_score is not None:
            losses['loss_cls'] = weighted_cross_entropy(
                cls_score, labels, label_weights)
            losses['acc'] = accuracy(cls_score, labels)
        if bbox_pred is not None:
            # The regression loss is averaged over the first dimension of
            # bbox_targets.
            losses['loss_reg'] = weighted_smoothl1(
                bbox_pred,
                bbox_targets,
                bbox_weights,
                avg_factor=bbox_targets.size(0))
        return losses

    def get_det_bboxes(self,
                       rois,
                       cls_score,
                       bbox_pred,
                       img_shape,
                       scale_factor,
                       rescale=False,
                       nms_cfg=None):
        """Turn head outputs into boxes and scores, optionally applying NMS.

        Args:
            rois (Tensor): 2-D tensor; the first column is dropped and the
                remaining columns are used as boxes (presumably the first
                column is a batch index -- confirm against the caller).
                This tensor is never modified.
            cls_score (Tensor | list[Tensor] | None): Classification
                scores; a list is averaged element-wise before the softmax
                over dim 1.
            bbox_pred (Tensor | None): Regression deltas decoded with
                ``delta2bbox``; when None the RoIs themselves are returned
                as boxes.
            img_shape: Forwarded to ``delta2bbox``.
            scale_factor: Divisor applied to the boxes when ``rescale`` is
                True.
            rescale (bool): Whether to divide the boxes by
                ``scale_factor``.
            nms_cfg: When None, NMS is skipped; otherwise must provide
                ``score_thr``, ``nms_thr`` and ``max_per_img``.

        Returns:
            tuple: ``(bboxes, scores)`` when ``nms_cfg`` is None, otherwise
            the ``(det_bboxes, det_labels)`` pair from ``multiclass_nms``.
        """
        if isinstance(cls_score, list):
            cls_score = sum(cls_score) / float(len(cls_score))
        scores = F.softmax(cls_score, dim=1) if cls_score is not None else None

        if bbox_pred is not None:
            bboxes = delta2bbox(rois[:, 1:], bbox_pred, self.target_means,
                                self.target_stds, img_shape)
        else:
            bboxes = rois[:, 1:]
            # TODO: add clip here

        if rescale:
            # Out-of-place on purpose: without bbox_pred, `bboxes` is a view
            # of `rois`, and an in-place `/=` would rescale the caller's
            # tensor as a hidden side effect.
            bboxes = bboxes / scale_factor

        if nms_cfg is None:
            return bboxes, scores
        else:
            det_bboxes, det_labels = multiclass_nms(
                bboxes, scores, nms_cfg.score_thr, nms_cfg.nms_thr,
                nms_cfg.max_per_img)

            return det_bboxes, det_labels