import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import xavier_init

from mmdet.core import AnchorGenerator, anchor_target, multi_apply
from .anchor_head import AnchorHead
from ..losses import smooth_l1_loss
from ..registry import HEADS


# TODO: add loss evaluator for SSD
@HEADS.register_module
class SSDHead(AnchorHead):

    def __init__(self,
                 input_size=300,
                 num_classes=81,
                 in_channels=(512, 1024, 512, 256, 256, 256),
                 anchor_strides=(8, 16, 32, 64, 100, 300),
                 basesize_ratio_range=(0.1, 0.9),
                 anchor_ratios=([2], [2, 3], [2, 3], [2, 3], [2], [2]),
                 target_means=(.0, .0, .0, .0),
                 target_stds=(1.0, 1.0, 1.0, 1.0)):
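        # NOTE: this deliberately calls nn.Module.__init__() (the class after
        # AnchorHead in the MRO) rather than AnchorHead.__init__(), since
        # SSDHead builds its own anchor generators and conv predictors below.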
        super(AnchorHead, self).__init__()
        self.input_size = input_size
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.cls_out_channels = num_classes
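        # Each level has 2 square anchors (scales 1 and
        # sqrt(max_size / min_size)) plus a pair (r, 1/r) of anchors per
        # aspect ratio, i.e. 4 or 6 anchors per location.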
        num_anchors = [len(ratios) * 2 + 2 for ratios in anchor_ratios]
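        # Per-level 3x3 conv predictors: the regression branch outputs 4 box
        # offsets per anchor, the classification branch per-class scores.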
        reg_convs = []
        cls_convs = []
        for i in range(len(in_channels)):
            reg_convs.append(
                nn.Conv2d(
                    in_channels[i],
                    num_anchors[i] * 4,
                    kernel_size=3,
                    padding=1))
            cls_convs.append(
                nn.Conv2d(
                    in_channels[i],
                    num_anchors[i] * num_classes,
                    kernel_size=3,
                    padding=1))
        self.reg_convs = nn.ModuleList(reg_convs)
        self.cls_convs = nn.ModuleList(cls_convs)

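        # Compute per-level anchor base sizes as in the SSD paper: scale
        # ratios are spaced evenly over basesize_ratio_range for all levels
        # except the first, then converted to pixels.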
        min_ratio, max_ratio = basesize_ratio_range
        min_ratio = int(min_ratio * 100)
        max_ratio = int(max_ratio * 100)
        step = (max_ratio - min_ratio) // (len(in_channels) - 2)
        min_sizes = []
        max_sizes = []
        for r in range(min_ratio, max_ratio + 1, step):
            min_sizes.append(int(input_size * r / 100))
            max_sizes.append(int(input_size * (r + step) / 100))
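        # The first feature level uses smaller, hand-picked sizes that depend
        # on the dataset, following the original Caffe SSD configs.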
        if input_size == 300:
            if basesize_ratio_range[0] == 0.15:  # SSD300 COCO
                min_sizes.insert(0, int(input_size * 7 / 100))
                max_sizes.insert(0, int(input_size * 15 / 100))
            elif basesize_ratio_range[0] == 0.2:  # SSD300 VOC
                min_sizes.insert(0, int(input_size * 10 / 100))
                max_sizes.insert(0, int(input_size * 20 / 100))
        elif input_size == 512:
            if basesize_ratio_range[0] == 0.1:  # SSD512 COCO
                min_sizes.insert(0, int(input_size * 4 / 100))
                max_sizes.insert(0, int(input_size * 10 / 100))
            elif basesize_ratio_range[0] == 0.15:  # SSD512 VOC
                min_sizes.insert(0, int(input_size * 7 / 100))
                max_sizes.insert(0, int(input_size * 15 / 100))
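        # Build one anchor generator per level. base_size is the level's
        # min_size; the extra scale sqrt(max_size / min_size) gives the
        # paper's second square anchor s'_k = sqrt(s_k * s_{k+1}).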
        self.anchor_generators = []
        self.anchor_strides = anchor_strides
        for k in range(len(anchor_strides)):
            base_size = min_sizes[k]
            stride = anchor_strides[k]
            ctr = ((stride - 1) / 2., (stride - 1) / 2.)
            scales = [1., np.sqrt(max_sizes[k] / min_sizes[k])]
            ratios = [1.]
            for r in anchor_ratios[k]:
                ratios += [1 / r, r]  # yields 4 or 6 anchors per location
            anchor_generator = AnchorGenerator(
                base_size, scales, ratios, scale_major=False, ctr=ctr)
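            # Keep one anchor per ratio at the base scale plus the single
            # sqrt-scale square anchor, placed second to match the anchor
            # order of the original Caffe SSD.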
            indices = list(range(len(ratios)))
            indices.insert(1, len(indices))
            anchor_generator.base_anchors = torch.index_select(
                anchor_generator.base_anchors, 0, torch.LongTensor(indices))
            self.anchor_generators.append(anchor_generator)

        self.target_means = target_means
        self.target_stds = target_stds
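        # SSD uses softmax classification with hard negative mining instead
        # of sigmoid or focal loss.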
        self.use_sigmoid_cls = False
        self.cls_focal_loss = False

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                xavier_init(m, distribution='uniform', bias=0)

    def forward(self, feats):
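        # Apply each level's cls/reg convs to its feature map, returning
        # per-level score maps and box-offset maps.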
        cls_scores = []
        bbox_preds = []
        for feat, reg_conv, cls_conv in zip(feats, self.reg_convs,
                                            self.cls_convs):
            cls_scores.append(cls_conv(feat))
            bbox_preds.append(reg_conv(feat))
        return cls_scores, bbox_preds

    def loss_single(self, cls_score, bbox_pred, labels, label_weights,
                    bbox_targets, bbox_weights, num_total_samples, cfg):
        loss_cls_all = F.cross_entropy(
            cls_score, labels, reduction='none') * label_weights
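        # Hard negative mining: keep all positive anchors, plus at most
        # neg_pos_ratio * num_pos negatives with the highest cls loss.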
        pos_inds = (labels > 0).nonzero().view(-1)
        neg_inds = (labels == 0).nonzero().view(-1)

        num_pos_samples = pos_inds.size(0)
        num_neg_samples = cfg.neg_pos_ratio * num_pos_samples
        if num_neg_samples > neg_inds.size(0):
            num_neg_samples = neg_inds.size(0)
        topk_loss_cls_neg, _ = loss_cls_all[neg_inds].topk(num_neg_samples)
        loss_cls_pos = loss_cls_all[pos_inds].sum()
        loss_cls_neg = topk_loss_cls_neg.sum()
        loss_cls = (loss_cls_pos + loss_cls_neg) / num_total_samples

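        # Smooth L1 regression loss on positive anchors only (bbox_weights
        # zeroes out negatives).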
        loss_bbox = smooth_l1_loss(
            bbox_pred,
            bbox_targets,
            bbox_weights,
            beta=cfg.smoothl1_beta,
            avg_factor=num_total_samples)
        return loss_cls[None], loss_bbox

    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == len(self.anchor_generators)

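        # Generate multi-level anchors for each image, then match them to
        # ground truth without random sampling; negatives are selected later
        # by hard negative mining in loss_single.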
        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, img_metas)
        cls_reg_targets = anchor_target(
            anchor_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            self.target_means,
            self.target_stds,
            cfg,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=1,
            sampling=False,
            unmap_outputs=False)
        if cls_reg_targets is None:
            return None
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg) = cls_reg_targets

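        # Flatten per-level predictions and targets into per-image tensors
        # of shape (num_images, total_anchors, ...).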
        num_images = len(img_metas)
        all_cls_scores = torch.cat([
            s.permute(0, 2, 3, 1).reshape(
                num_images, -1, self.cls_out_channels) for s in cls_scores
        ], 1)
        all_labels = torch.cat(labels_list, -1).view(num_images, -1)
        all_label_weights = torch.cat(label_weights_list,
                                      -1).view(num_images, -1)
        all_bbox_preds = torch.cat([
            b.permute(0, 2, 3, 1).reshape(num_images, -1, 4)
            for b in bbox_preds
        ], -2)
        all_bbox_targets = torch.cat(bbox_targets_list,
                                     -2).view(num_images, -1, 4)
        all_bbox_weights = torch.cat(bbox_weights_list,
                                     -2).view(num_images, -1, 4)

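        # Compute the loss image by image; per the SSD convention, losses
        # are normalized by the number of positive anchors.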
        losses_cls, losses_bbox = multi_apply(
            self.loss_single,
            all_cls_scores,
            all_bbox_preds,
            all_labels,
            all_label_weights,
            all_bbox_targets,
            all_bbox_weights,
            num_total_samples=num_total_pos,
            cfg=cfg)
        return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
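

# A minimal usage sketch (hypothetical config values; the six input feature
# maps must have channel counts matching `in_channels`):
#
#   head = SSDHead(input_size=300, num_classes=21,
#                  basesize_ratio_range=(0.2, 0.9))
#   head.init_weights()
#   cls_scores, bbox_preds = head(feats)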