# ssd_head.py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import xavier_init

from mmdet.core import (AnchorGenerator, anchor_target, weighted_smoothl1,
                        multi_apply)
from .anchor_head import AnchorHead
from ..registry import HEADS


@HEADS.register_module
class SSDHead(AnchorHead):
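    """Anchor-based detection head used by SSD.

    Each feature level gets a 3x3 regression conv and a 3x3 classification
    conv, and training uses online hard negative mining as in the original
    SSD paper (Liu et al., ECCV 2016).
    """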

    def __init__(self,
                 input_size=300,
                 num_classes=81,
                 in_channels=(512, 1024, 512, 256, 256, 256),
                 anchor_strides=(8, 16, 32, 64, 100, 300),
                 basesize_ratio_range=(0.1, 0.9),
                 anchor_ratios=([2], [2, 3], [2, 3], [2, 3], [2], [2]),
                 target_means=(.0, .0, .0, .0),
                 target_stds=(1.0, 1.0, 1.0, 1.0)):
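        # Call nn.Module.__init__ directly, deliberately skipping
        # AnchorHead.__init__: this head builds its own anchor generators
        # and conv layers below instead of reusing the parent's setup.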
        super(AnchorHead, self).__init__()
        self.input_size = input_size
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.cls_out_channels = num_classes
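        # Anchors per location: two square anchors (scale 1 and
        # sqrt(min_size * max_size)) plus a (1/r, r) pair for every extra
        # aspect ratio, i.e. 4 or 6 depending on the level.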
        num_anchors = [len(ratios) * 2 + 2 for ratios in anchor_ratios]
        reg_convs = []
        cls_convs = []
        for i in range(len(in_channels)):
            reg_convs.append(
                nn.Conv2d(
                    in_channels[i],
                    num_anchors[i] * 4,
                    kernel_size=3,
                    padding=1))
            cls_convs.append(
                nn.Conv2d(
                    in_channels[i],
                    num_anchors[i] * num_classes,
                    kernel_size=3,
                    padding=1))
        self.reg_convs = nn.ModuleList(reg_convs)
        self.cls_convs = nn.ModuleList(cls_convs)

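        # Split basesize_ratio_range evenly into one base-size interval per
        # feature level; the sizes below cover all levels but the first,
        # which is special-cased afterwards.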
        min_ratio, max_ratio = basesize_ratio_range
        min_ratio = int(min_ratio * 100)
        max_ratio = int(max_ratio * 100)
        step = int(np.floor((max_ratio - min_ratio) / (len(in_channels) - 2)))
        min_sizes = []
        max_sizes = []
        for r in range(int(min_ratio), int(max_ratio) + 1, step):
            min_sizes.append(int(input_size * r / 100))
            max_sizes.append(int(input_size * (r + step) / 100))
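        # The first level does not follow the even split; use the ratios
        # from the reference Caffe SSD configs for each input size/dataset.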
        if input_size == 300:
            if basesize_ratio_range[0] == 0.15:  # SSD300 COCO
                min_sizes.insert(0, int(input_size * 7 / 100))
                max_sizes.insert(0, int(input_size * 15 / 100))
            elif basesize_ratio_range[0] == 0.2:  # SSD300 VOC
                min_sizes.insert(0, int(input_size * 10 / 100))
                max_sizes.insert(0, int(input_size * 20 / 100))
        elif input_size == 512:
            if basesize_ratio_range[0] == 0.1:  # SSD512 COCO
                min_sizes.insert(0, int(input_size * 4 / 100))
                max_sizes.insert(0, int(input_size * 10 / 100))
            elif basesize_ratio_range[0] == 0.15:  # SSD512 VOC
                min_sizes.insert(0, int(input_size * 7 / 100))
                max_sizes.insert(0, int(input_size * 15 / 100))
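        # Build one AnchorGenerator per feature level, with anchors centred
        # on the pixel grid of that level's stride.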
        self.anchor_generators = []
        self.anchor_strides = anchor_strides
        for k in range(len(anchor_strides)):
            base_size = min_sizes[k]
            stride = anchor_strides[k]
            ctr = ((stride - 1) / 2., (stride - 1) / 2.)
            scales = [1., np.sqrt(max_sizes[k] / min_sizes[k])]
            ratios = [1.]
            for r in anchor_ratios[k]:
                ratios += [1 / r, r]  # 4 or 6 ratios in total
            anchor_generator = AnchorGenerator(
                base_size, scales, ratios, scale_major=False, ctr=ctr)
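            # The generator emits len(scales) * len(ratios) anchors, but SSD
            # only keeps non-square anchors at the smaller scale. Select the
            # two squares plus the small-scale ratio anchors, reordered as
            # [1, 1', 1/r, r, ...] to match the original SSD anchor order.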
            indices = list(range(len(ratios)))
            indices.insert(1, len(indices))
            anchor_generator.base_anchors = torch.index_select(
                anchor_generator.base_anchors, 0, torch.LongTensor(indices))
            self.anchor_generators.append(anchor_generator)

        self.target_means = target_means
        self.target_stds = target_stds
        self.use_sigmoid_cls = False
        self.cls_focal_loss = False

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                xavier_init(m, distribution='uniform', bias=0)

    def forward(self, feats):
        cls_scores = []
        bbox_preds = []
        for feat, reg_conv, cls_conv in zip(feats, self.reg_convs,
                                            self.cls_convs):
            cls_scores.append(cls_conv(feat))
            bbox_preds.append(reg_conv(feat))
        return cls_scores, bbox_preds

    def loss_single(self, cls_score, bbox_pred, labels, label_weights,
                    bbox_targets, bbox_weights, num_total_samples, cfg):
        loss_cls_all = F.cross_entropy(
            cls_score, labels, reduction='none') * label_weights
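        # Online hard negative mining: anchors labelled 0 are background.
        # Keep every positive, but only the cfg.neg_pos_ratio hardest
        # negatives (those with the highest classification loss).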
        pos_inds = (labels > 0).nonzero().view(-1)
        neg_inds = (labels == 0).nonzero().view(-1)

        num_pos_samples = pos_inds.size(0)
        num_neg_samples = cfg.neg_pos_ratio * num_pos_samples
        if num_neg_samples > neg_inds.size(0):
            num_neg_samples = neg_inds.size(0)
        topk_loss_cls_neg, _ = loss_cls_all[neg_inds].topk(num_neg_samples)
        loss_cls_pos = loss_cls_all[pos_inds].sum()
        loss_cls_neg = topk_loss_cls_neg.sum()
        loss_cls = (loss_cls_pos + loss_cls_neg) / num_total_samples

        loss_reg = weighted_smoothl1(
            bbox_pred,
            bbox_targets,
            bbox_weights,
            beta=cfg.smoothl1_beta,
            avg_factor=num_total_samples)
        return loss_cls[None], loss_reg

    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == len(self.anchor_generators)

        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, img_metas)
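        # Match anchors to ground truth without random sampling
        # (sampling=False): SSD balances positives and negatives later via
        # hard negative mining in loss_single.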
        cls_reg_targets = anchor_target(
            anchor_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            self.target_means,
            self.target_stds,
            cfg,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=1,
            sampling=False,
            unmap_outputs=False)
        if cls_reg_targets is None:
            return None
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg) = cls_reg_targets

        num_images = len(img_metas)
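        # Flatten the per-level (N, A*C, H, W) outputs into per-image
        # (num_anchors, C) scores and (num_anchors, 4) bbox predictions so
        # that loss_single can run once per image via multi_apply.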
        all_cls_scores = torch.cat([
            s.permute(0, 2, 3, 1).reshape(
                num_images, -1, self.cls_out_channels) for s in cls_scores
        ], 1)
        all_labels = torch.cat(labels_list, -1).view(num_images, -1)
        all_label_weights = torch.cat(label_weights_list, -1).view(
            num_images, -1)
        all_bbox_preds = torch.cat([
            b.permute(0, 2, 3, 1).reshape(num_images, -1, 4)
            for b in bbox_preds
        ], -2)
        all_bbox_targets = torch.cat(bbox_targets_list, -2).view(
            num_images, -1, 4)
        all_bbox_weights = torch.cat(bbox_weights_list, -2).view(
            num_images, -1, 4)

        losses_cls, losses_reg = multi_apply(
            self.loss_single,
            all_cls_scores,
            all_bbox_preds,
            all_labels,
            all_label_weights,
            all_bbox_targets,
            all_bbox_weights,
            num_total_samples=num_total_pos,
            cfg=cfg)
        return dict(loss_cls=losses_cls, loss_reg=losses_reg)