import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import xavier_init

from mmdet.core import (AnchorGenerator, anchor_target, weighted_smoothl1,
                        multi_apply)
from .anchor_head import AnchorHead


class SSDHead(AnchorHead):
    """Detection head of SSD (https://arxiv.org/abs/1512.02325): one 3x3
    classification conv and one 3x3 box-regression conv per feature level,
    with the anchor layout derived from ``basesize_ratio_range``."""

    def __init__(self,
                 input_size=300,
                 num_classes=81,
                 in_channels=(512, 1024, 512, 256, 256, 256),
                 anchor_strides=(8, 16, 32, 64, 100, 300),
                 basesize_ratio_range=(0.1, 0.9),
                 anchor_ratios=([2], [2, 3], [2, 3], [2, 3], [2], [2]),
                 target_means=(.0, .0, .0, .0),
                 target_stds=(1.0, 1.0, 1.0, 1.0)):
        # Deliberately skip AnchorHead.__init__ (which builds its own anchor
        # generators) and initialize nn.Module directly.
        super(AnchorHead, self).__init__()
        self.input_size = input_size
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.cls_out_channels = num_classes
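        # Each aspect ratio r contributes a pair of anchors (r and 1/r), and
        # there are two square anchors (base scale and sqrt(max/min) scale),
        # hence 2 * len(ratios) + 2 anchors per location, i.e. 4 or 6.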
        num_anchors = [len(ratios) * 2 + 2 for ratios in anchor_ratios]
        # One 3x3 conv pair per feature level: regression predicts 4 box
        # deltas per anchor, classification predicts num_classes scores.
        reg_convs = []
        cls_convs = []
        for i in range(len(in_channels)):
            reg_convs.append(
                nn.Conv2d(
                    in_channels[i],
                    num_anchors[i] * 4,
                    kernel_size=3,
                    padding=1))
            cls_convs.append(
                nn.Conv2d(
                    in_channels[i],
                    num_anchors[i] * num_classes,
                    kernel_size=3,
                    padding=1))
        self.reg_convs = nn.ModuleList(reg_convs)
        self.cls_convs = nn.ModuleList(cls_convs)

        min_ratio, max_ratio = basesize_ratio_range
        min_ratio = int(min_ratio * 100)
        max_ratio = int(max_ratio * 100)
        step = int(np.floor((max_ratio - min_ratio) / (len(in_channels) - 2)))
        min_sizes = []
        max_sizes = []
        for r in range(min_ratio, max_ratio + 1, step):
            min_sizes.append(int(input_size * r / 100))
            max_sizes.append(int(input_size * (r + step) / 100))
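        # Worked example (SSD300 VOC, basesize_ratio_range=(0.2, 0.9)):
        # min_ratio=20, max_ratio=90, step=17, so r runs over 20, 37, 54, 71,
        # 88, giving min_sizes=[60, 111, 162, 213, 264] and max_sizes=[111,
        # 162, 213, 264, 315]. The loop covers only the last five feature
        # levels; the branches below prepend the first level's sizes, so only
        # the input_size/ratio-range pairs handled there yield the six levels
        # that anchor_strides expects.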
        if input_size == 300:
            if basesize_ratio_range[0] == 0.15:  # SSD300 COCO
                min_sizes.insert(0, int(input_size * 7 / 100))
                max_sizes.insert(0, int(input_size * 15 / 100))
            elif basesize_ratio_range[0] == 0.2:  # SSD300 VOC
                min_sizes.insert(0, int(input_size * 10 / 100))
                max_sizes.insert(0, int(input_size * 20 / 100))
        elif input_size == 512:
            if basesize_ratio_range[0] == 0.1:  # SSD512 COCO
                min_sizes.insert(0, int(input_size * 4 / 100))
                max_sizes.insert(0, int(input_size * 10 / 100))
            elif basesize_ratio_range[0] == 0.15:  # SSD512 VOC
                min_sizes.insert(0, int(input_size * 7 / 100))
                max_sizes.insert(0, int(input_size * 15 / 100))
        self.anchor_generators = []
        self.anchor_strides = anchor_strides
        for k in range(len(anchor_strides)):
            base_size = min_sizes[k]
            stride = anchor_strides[k]
            ctr = ((stride - 1) / 2., (stride - 1) / 2.)
            scales = [1., np.sqrt(max_sizes[k] / min_sizes[k])]
            ratios = [1.]
            for r in anchor_ratios[k]:
                ratios += [1 / r, r]  # [2] -> 4 anchors, [2, 3] -> 6 anchors
            anchor_generator = AnchorGenerator(
                base_size, scales, ratios, scale_major=False, ctr=ctr)
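            # With scale_major=False the generator emits all ratios at the
            # base scale first, then all ratios at the larger scale. SSD
            # keeps only len(ratios) + 1 of these: every ratio at the base
            # scale plus the larger square anchor, which is moved to index 1
            # to match the anchor ordering of the original Caffe SSD.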
            indices = list(range(len(ratios)))
            indices.insert(1, len(indices))
            anchor_generator.base_anchors = torch.index_select(
                anchor_generator.base_anchors, 0, torch.LongTensor(indices))
            self.anchor_generators.append(anchor_generator)

        self.target_means = target_means
        self.target_stds = target_stds
        self.use_sigmoid_cls = False
        self.use_focal_loss = False

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                xavier_init(m, distribution='uniform', bias=0)

    def forward(self, feats):
        cls_scores = []
        bbox_preds = []
        for feat, reg_conv, cls_conv in zip(feats, self.reg_convs,
                                            self.cls_convs):
            cls_scores.append(cls_conv(feat))
            bbox_preds.append(reg_conv(feat))
        return cls_scores, bbox_preds

    def loss_single(self, cls_score, bbox_pred, labels, label_weights,
                    bbox_targets, bbox_weights, num_total_samples, cfg):
        loss_cls_all = F.cross_entropy(
            cls_score, labels, reduction='none') * label_weights
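        # Hard negative mining: keep every positive anchor, but only the
        # neg_pos_ratio (typically 3x positives) negatives with the highest
        # classification loss, so easy backgrounds do not dominate the loss.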
        pos_inds = (labels > 0).nonzero().view(-1)
        neg_inds = (labels == 0).nonzero().view(-1)

        num_pos_samples = pos_inds.size(0)
        num_neg_samples = cfg.neg_pos_ratio * num_pos_samples
        if num_neg_samples > neg_inds.size(0):
            num_neg_samples = neg_inds.size(0)
        topk_loss_cls_neg, _ = loss_cls_all[neg_inds].topk(num_neg_samples)
        loss_cls_pos = loss_cls_all[pos_inds].sum()
        loss_cls_neg = topk_loss_cls_neg.sum()
        loss_cls = (loss_cls_pos + loss_cls_neg) / num_total_samples

        loss_reg = weighted_smoothl1(
            bbox_pred,
            bbox_targets,
            bbox_weights,
            beta=cfg.smoothl1_beta,
            avg_factor=num_total_samples)
        # loss_cls is a 0-dim tensor; [None] adds a length-1 dim so that
        # multi_apply can collect the per-image results.
        return loss_cls[None], loss_reg

    def loss(self, cls_scores, bbox_preds, gt_bboxes, gt_labels, img_metas,
             cfg):
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == len(self.anchor_generators)

        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, img_metas)
        cls_reg_targets = anchor_target(
            anchor_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            self.target_means,
            self.target_stds,
            cfg,
            gt_labels_list=gt_labels,
            label_channels=1,
            sampling=False,
            unmap_outputs=False)
        if cls_reg_targets is None:
            return None
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg) = cls_reg_targets

        num_images = len(img_metas)
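        # Unlike most anchor heads, which compute the loss per feature level,
        # SSD flattens predictions and targets from all levels into per-image
        # tensors and applies loss_single once per image, so hard negative
        # mining sees each image's full set of anchors.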
        all_cls_scores = torch.cat([
            s.permute(0, 2, 3, 1).reshape(
                num_images, -1, self.cls_out_channels) for s in cls_scores
        ], 1)
        all_labels = torch.cat(labels_list, -1).view(num_images, -1)
        all_label_weights = torch.cat(label_weights_list, -1).view(
            num_images, -1)
        all_bbox_preds = torch.cat([
            b.permute(0, 2, 3, 1).reshape(num_images, -1, 4)
            for b in bbox_preds
        ], -2)
        all_bbox_targets = torch.cat(bbox_targets_list, -2).view(
            num_images, -1, 4)
        all_bbox_weights = torch.cat(bbox_weights_list, -2).view(
            num_images, -1, 4)

        losses_cls, losses_reg = multi_apply(
            self.loss_single,
            all_cls_scores,
            all_bbox_preds,
            all_labels,
            all_label_weights,
            all_bbox_targets,
            all_bbox_weights,
            num_total_samples=num_total_pos,
            cfg=cfg)
        return dict(loss_cls=losses_cls, loss_reg=losses_reg)
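
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). The relative
# import above means this file cannot run as a script, so the example is left
# commented out; it assumes the standard SSD300 feature pyramid with spatial
# sizes (38, 19, 10, 5, 3, 1) and the VOC setting
# basesize_ratio_range=(0.2, 0.9), which gives (4, 6, 6, 6, 4, 4) anchors per
# location across the six levels.
# ---------------------------------------------------------------------------
# head = SSDHead(input_size=300, num_classes=21,
#                basesize_ratio_range=(0.2, 0.9))
# head.init_weights()
# feats = [torch.rand(1, c, s, s) for c, s in
#          zip((512, 1024, 512, 256, 256, 256), (38, 19, 10, 5, 3, 1))]
# cls_scores, bbox_preds = head(feats)
# # cls_scores[0]: (1, 4 * 21, 38, 38); bbox_preds[0]: (1, 4 * 4, 38, 38)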