"tests/models/xlnet/test_modeling_xlnet.py" did not exist on "728637356c7ff9ab969f107d3c67790b4400bac9"
ssd_head.py 7.25 KB
Newer Older
yhcao6's avatar
yhcao6 committed
1
2
3
4
5
6
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import xavier_init

from mmdet.core import (AnchorGenerator, anchor_target, weighted_smoothl1,
                        multi_apply)
from .anchor_head import AnchorHead
from ..registry import HEADS


@HEADS.register_module
class SSDHead(AnchorHead):
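    """SSD detection head (Liu et al., "SSD: Single Shot MultiBox Detector",
    ECCV 2016).

    Predicts class scores and box offsets from multi-level features with one
    3x3 conv per branch per level, and trains with online hard negative
    mining (see ``loss_single``).
    """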

    def __init__(self,
                 input_size=300,
                 num_classes=81,
                 in_channels=(512, 1024, 512, 256, 256, 256),
                 anchor_strides=(8, 16, 32, 64, 100, 300),
                 basesize_ratio_range=(0.1, 0.9),
                 anchor_ratios=([2], [2, 3], [2, 3], [2, 3], [2], [2]),
                 target_means=(.0, .0, .0, .0),
                 target_stds=(1.0, 1.0, 1.0, 1.0)):
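        # Intentionally skip AnchorHead.__init__ (only nn.Module.__init__
        # runs): SSDHead builds its own conv branches and anchor generators
        # below instead of the generic AnchorHead layers.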
        super(AnchorHead, self).__init__()
        self.input_size = input_size
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.cls_out_channels = num_classes
        num_anchors = [len(ratios) * 2 + 2 for ratios in anchor_ratios]
        reg_convs = []
        cls_convs = []
        for i in range(len(in_channels)):
            reg_convs.append(
                nn.Conv2d(
                    in_channels[i],
                    num_anchors[i] * 4,
                    kernel_size=3,
                    padding=1))
            cls_convs.append(
                nn.Conv2d(
                    in_channels[i],
                    num_anchors[i] * num_classes,
                    kernel_size=3,
                    padding=1))
        self.reg_convs = nn.ModuleList(reg_convs)
        self.cls_convs = nn.ModuleList(cls_convs)

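        # Derive per-level base anchor sizes from basesize_ratio_range,
        # following the SSD paper's linear scale rule
        # s_k = s_min + (s_max - s_min) * (k - 1) / (m - 1),
        # expressed here in percent of the input size.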
        min_ratio, max_ratio = basesize_ratio_range
        min_ratio = int(min_ratio * 100)
        max_ratio = int(max_ratio * 100)
        step = int(np.floor((max_ratio - min_ratio) / (len(in_channels) - 2)))
        min_sizes = []
        max_sizes = []
        for r in range(min_ratio, max_ratio + 1, step):
            min_sizes.append(int(input_size * r / 100))
            max_sizes.append(int(input_size * (r + step) / 100))
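        # The lowest pyramid level uses hand-picked sizes depending on the
        # dataset and input resolution, mirroring the original Caffe
        # implementation of SSD.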
        if input_size == 300:
            if basesize_ratio_range[0] == 0.15:  # SSD300 COCO
                min_sizes.insert(0, int(input_size * 7 / 100))
                max_sizes.insert(0, int(input_size * 15 / 100))
            elif basesize_ratio_range[0] == 0.2:  # SSD300 VOC
                min_sizes.insert(0, int(input_size * 10 / 100))
                max_sizes.insert(0, int(input_size * 20 / 100))
        elif input_size == 512:
            if basesize_ratio_range[0] == 0.1:  # SSD512 COCO
                min_sizes.insert(0, int(input_size * 4 / 100))
                max_sizes.insert(0, int(input_size * 10 / 100))
            elif basesize_ratio_range[0] == 0.15:  # SSD512 VOC
                min_sizes.insert(0, int(input_size * 7 / 100))
                max_sizes.insert(0, int(input_size * 15 / 100))
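
        # One AnchorGenerator per level: each level gets a single base size,
        # an extra scale of sqrt(max_size / min_size), and aspect ratios
        # {1} plus {r, 1/r} for every configured ratio.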
        self.anchor_generators = []
        self.anchor_strides = anchor_strides
        for k in range(len(anchor_strides)):
            base_size = min_sizes[k]
            stride = anchor_strides[k]
            ctr = ((stride - 1) / 2., (stride - 1) / 2.)
            scales = [1., np.sqrt(max_sizes[k] / min_sizes[k])]
            ratios = [1.]
            for r in anchor_ratios[k]:
                ratios += [1 / r, r]  # yields 4 or 6 anchors per location
            anchor_generator = AnchorGenerator(
                base_size, scales, ratios, scale_major=False, ctr=ctr)
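            # Reorder the base anchors so the sqrt(max/min)-scale square
            # anchor comes second, matching the prior order of the original
            # Caffe SSD implementation.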
            indices = list(range(len(ratios)))
            indices.insert(1, len(indices))
            anchor_generator.base_anchors = torch.index_select(
                anchor_generator.base_anchors, 0, torch.LongTensor(indices))
            self.anchor_generators.append(anchor_generator)

        self.target_means = target_means
        self.target_stds = target_stds
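        # SSD classifies with softmax cross-entropy and mines hard negatives,
        # so AnchorHead's sigmoid and focal-loss paths are disabled.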
        self.use_sigmoid_cls = False
        self.use_focal_loss = False

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                xavier_init(m, distribution='uniform', bias=0)

    def forward(self, feats):
        cls_scores = []
        bbox_preds = []
        for feat, reg_conv, cls_conv in zip(feats, self.reg_convs,
                                            self.cls_convs):
            cls_scores.append(cls_conv(feat))
            bbox_preds.append(reg_conv(feat))
        return cls_scores, bbox_preds

    def loss_single(self, cls_score, bbox_pred, labels, label_weights,
                    bbox_targets, bbox_weights, num_total_samples, cfg):
        loss_cls_all = F.cross_entropy(
            cls_score, labels, reduction='none') * label_weights
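
        # Online hard negative mining: keep all positive anchors, but only
        # the cfg.neg_pos_ratio hardest negatives (those with the largest
        # classification loss), as in the SSD paper.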
        pos_inds = (labels > 0).nonzero().view(-1)
        neg_inds = (labels == 0).nonzero().view(-1)

        num_pos_samples = pos_inds.size(0)
        num_neg_samples = cfg.neg_pos_ratio * num_pos_samples
        if num_neg_samples > neg_inds.size(0):
            num_neg_samples = neg_inds.size(0)
        topk_loss_cls_neg, _ = loss_cls_all[neg_inds].topk(num_neg_samples)
        loss_cls_pos = loss_cls_all[pos_inds].sum()
        loss_cls_neg = topk_loss_cls_neg.sum()
        loss_cls = (loss_cls_pos + loss_cls_neg) / num_total_samples

        loss_reg = weighted_smoothl1(
            bbox_pred,
            bbox_targets,
            bbox_weights,
            beta=cfg.smoothl1_beta,
            avg_factor=num_total_samples)
        return loss_cls[None], loss_reg

    def loss(self, cls_scores, bbox_preds, gt_bboxes, gt_labels, img_metas,
             cfg):
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == len(self.anchor_generators)

        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, img_metas)
        cls_reg_targets = anchor_target(
            anchor_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            self.target_means,
            self.target_stds,
            cfg,
            gt_labels_list=gt_labels,
            label_channels=1,
            sampling=False,
            unmap_outputs=False)
        if cls_reg_targets is None:
            return None
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg) = cls_reg_targets

        num_images = len(img_metas)
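        # Flatten every level into per-image tensors so that multi_apply
        # calls loss_single once per image over all anchors, letting hard
        # negative mining operate across the whole image.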
        all_cls_scores = torch.cat([
            s.permute(0, 2, 3, 1).reshape(
                num_images, -1, self.cls_out_channels) for s in cls_scores
        ], 1)
        all_labels = torch.cat(labels_list, -1).view(num_images, -1)
        all_label_weights = torch.cat(label_weights_list, -1).view(
            num_images, -1)
        all_bbox_preds = torch.cat([
            b.permute(0, 2, 3, 1).reshape(num_images, -1, 4)
            for b in bbox_preds
        ], -2)
        all_bbox_targets = torch.cat(bbox_targets_list, -2).view(
            num_images, -1, 4)
        all_bbox_weights = torch.cat(bbox_weights_list, -2).view(
            num_images, -1, 4)

        losses_cls, losses_reg = multi_apply(
            self.loss_single,
            all_cls_scores,
            all_bbox_preds,
            all_labels,
            all_label_weights,
            all_bbox_targets,
            all_bbox_weights,
            num_total_samples=num_total_pos,
            cfg=cfg)
        return dict(loss_cls=losses_cls, loss_reg=losses_reg)
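

if __name__ == '__main__':
    # Minimal smoke-test sketch (illustration only, not part of the original
    # module): build the default SSD300 head and run a forward pass on random
    # feature maps whose channels and spatial sizes match the standard VGG16
    # SSD300 pyramid.
    head = SSDHead()
    head.init_weights()
    channels = (512, 1024, 512, 256, 256, 256)
    sizes = (38, 19, 10, 5, 3, 1)
    feats = [torch.rand(2, c, s, s) for c, s in zip(channels, sizes)]
    cls_scores, bbox_preds = head(feats)
    for cls, reg in zip(cls_scores, bbox_preds):
        # cls: (N, num_anchors * num_classes, H, W)
        # reg: (N, num_anchors * 4, H, W)
        print(cls.shape, reg.shape)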