"llama/ggml-cuda/sum.cu" did not exist on "05cd82ef94a5c4b14bc030dd93a94e18ed63e295"
ssd_head.py 7.84 KB
Newer Older
yhcao6's avatar
yhcao6 committed
1
2
3
4
5
6
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import xavier_init

Kai Chen's avatar
Kai Chen committed
7
8
from mmdet.core import AnchorGenerator, anchor_target, multi_apply
from ..losses import smooth_l1_loss
Kai Chen's avatar
Kai Chen committed
9
from ..registry import HEADS
10
from .anchor_head import AnchorHead
yhcao6's avatar
yhcao6 committed
11
12


Jiangmiao Pang's avatar
Jiangmiao Pang committed
13
# TODO: add loss evaluator for SSD
Kai Chen's avatar
Kai Chen committed
14
@HEADS.register_module
class SSDHead(AnchorHead):
    """Single Shot MultiBox Detector (SSD) head.

    For every input feature level one 3x3 classification conv and one 3x3
    regression conv are applied, and one ``AnchorGenerator`` produces the
    level's base anchors. Classification uses a softmax over ``num_classes``
    channels, with label 0 treated as negative/background in the loss.

    Args:
        input_size (int): Input image size. The extra first-level anchor
            sizes are only hard-coded for 300 and 512.
        num_classes (int): Number of softmax classification channels per
            anchor.
        in_channels (tuple[int]): Channels of each input feature level.
        anchor_strides (tuple[int]): Anchor stride of each level.
        basesize_ratio_range (tuple[float]): ``(min, max)`` anchor base size
            as a fraction of ``input_size``, spread over the levels; the
            first level's size is inserted separately for known SSD300 /
            SSD512 settings.
        anchor_ratios (tuple[list[int]]): Extra aspect ratios per level; each
            ratio ``r`` contributes both ``1 / r`` and ``r``.
        target_means (tuple[float]): Means passed to ``anchor_target`` for
            bbox target encoding.
        target_stds (tuple[float]): Stds passed to ``anchor_target`` for bbox
            target encoding.
    """

    def __init__(self,
                 input_size=300,
                 num_classes=81,
                 in_channels=(512, 1024, 512, 256, 256, 256),
                 anchor_strides=(8, 16, 32, 64, 100, 300),
                 basesize_ratio_range=(0.1, 0.9),
                 anchor_ratios=([2], [2, 3], [2, 3], [2, 3], [2], [2]),
                 target_means=(.0, .0, .0, .0),
                 target_stds=(1.0, 1.0, 1.0, 1.0)):
        # super(AnchorHead, self) starts the MRO lookup *after* AnchorHead,
        # so AnchorHead.__init__ itself is deliberately skipped and the
        # attributes used by the inherited methods are set by hand below.
        # NOTE(review): confirm this list stays in sync with AnchorHead.
        super(AnchorHead, self).__init__()
        self.input_size = input_size
        self.num_classes = num_classes
        self.in_channels = in_channels
        # Softmax classification: one output channel per class.
        self.cls_out_channels = num_classes
        # Per level: two square anchors (scales 1 and sqrt(max/min)) plus a
        # (1/r, r) pair for every extra ratio r.
        num_anchors = [len(ratios) * 2 + 2 for ratios in anchor_ratios]

        reg_convs = []
        cls_convs = []
        for i, channels in enumerate(in_channels):
            reg_convs.append(
                nn.Conv2d(
                    channels, num_anchors[i] * 4, kernel_size=3, padding=1))
            cls_convs.append(
                nn.Conv2d(
                    channels,
                    num_anchors[i] * num_classes,
                    kernel_size=3,
                    padding=1))
        self.reg_convs = nn.ModuleList(reg_convs)
        self.cls_convs = nn.ModuleList(cls_convs)

        # Work in integer percent. round() (not a bare int()) avoids float
        # truncation, e.g. 0.29 * 100 == 28.999999999999996 -> 28.
        min_ratio, max_ratio = basesize_ratio_range
        min_pct = int(round(min_ratio * 100))
        max_pct = int(round(max_ratio * 100))
        step = int(np.floor((max_pct - min_pct) / (len(in_channels) - 2)))
        min_sizes = []
        max_sizes = []
        for r in range(min_pct, max_pct + 1, step):
            min_sizes.append(int(input_size * r / 100))
            max_sizes.append(int(input_size * (r + step) / 100))

        # The first (highest resolution) level gets a hand-picked, smaller
        # base size for the standard SSD300 / SSD512 configurations.
        if input_size == 300:
            if basesize_ratio_range[0] == 0.15:  # SSD300 COCO
                min_sizes.insert(0, int(input_size * 7 / 100))
                max_sizes.insert(0, int(input_size * 15 / 100))
            elif basesize_ratio_range[0] == 0.2:  # SSD300 VOC
                min_sizes.insert(0, int(input_size * 10 / 100))
                max_sizes.insert(0, int(input_size * 20 / 100))
        elif input_size == 512:
            if basesize_ratio_range[0] == 0.1:  # SSD512 COCO
                min_sizes.insert(0, int(input_size * 4 / 100))
                max_sizes.insert(0, int(input_size * 10 / 100))
            elif basesize_ratio_range[0] == 0.15:  # SSD512 VOC
                min_sizes.insert(0, int(input_size * 7 / 100))
                max_sizes.insert(0, int(input_size * 15 / 100))

        # Without this check an unsupported setting fails later with an
        # opaque IndexError inside the anchor-generator loop.
        if len(min_sizes) < len(anchor_strides):
            raise ValueError(
                'Got {} anchor base sizes for {} anchor strides; first-level '
                'sizes are only defined for input_size=300 with '
                'basesize_ratio_range[0] in (0.15, 0.2) and input_size=512 '
                'with basesize_ratio_range[0] in (0.1, 0.15), got '
                'input_size={} and basesize_ratio_range={}.'.format(
                    len(min_sizes), len(anchor_strides), input_size,
                    basesize_ratio_range))

        self.anchor_generators = []
        self.anchor_strides = anchor_strides
        for k, stride in enumerate(anchor_strides):
            base_size = min_sizes[k]
            ctr = ((stride - 1) / 2., (stride - 1) / 2.)
            scales = [1., np.sqrt(max_sizes[k] / min_sizes[k])]
            ratios = [1.]
            for r in anchor_ratios[k]:
                ratios += [1 / r, r]  # 4 or 6 ratio
            anchor_generator = AnchorGenerator(
                base_size, scales, ratios, scale_major=False, ctr=ctr)
            # Keep the first len(ratios) base anchors and insert anchor
            # number len(ratios) at position 1, giving
            # len(ratios) + 1 == num_anchors[k] anchors per location.
            # NOTE(review): assumes scale_major=False lays out all ratios of
            # scale 1 first, so index len(ratios) is the ratio-1 anchor of
            # the second scale - confirm against AnchorGenerator.
            indices = list(range(len(ratios)))
            indices.insert(1, len(indices))
            anchor_generator.base_anchors = torch.index_select(
                anchor_generator.base_anchors, 0, torch.LongTensor(indices))
            self.anchor_generators.append(anchor_generator)

        self.target_means = target_means
        self.target_stds = target_stds
        self.use_sigmoid_cls = False
        self.cls_focal_loss = False
        self.fp16_enabled = False

    def init_weights(self):
        """Xavier-uniform initialise every Conv2d in the head (bias 0)."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                xavier_init(m, distribution='uniform', bias=0)

    def forward(self, feats):
        """Apply the per-level classification and regression convs.

        Args:
            feats (Sequence[Tensor]): One feature map per level. ``zip``
                stops at the shortest of ``feats`` and the conv lists.

        Returns:
            tuple[list[Tensor], list[Tensor]]: ``(cls_scores, bbox_preds)``,
            one tensor per level with ``num_anchors * num_classes`` and
            ``num_anchors * 4`` channels respectively.
        """
        cls_scores = []
        bbox_preds = []
        for feat, reg_conv, cls_conv in zip(feats, self.reg_convs,
                                            self.cls_convs):
            cls_scores.append(cls_conv(feat))
            bbox_preds.append(reg_conv(feat))
        return cls_scores, bbox_preds

    def loss_single(self, cls_score, bbox_pred, labels, label_weights,
                    bbox_targets, bbox_weights, num_total_samples, cfg):
        """Compute the losses of one image with hard negative mining.

        All positives (``labels > 0``) contribute to the classification
        loss, plus the ``cfg.neg_pos_ratio * num_pos`` negatives
        (``labels == 0``) with the highest loss (capped at the number of
        negatives). Both losses are normalised by ``num_total_samples``.

        Args:
            cls_score (Tensor): Class logits, assumed (num_anchors,
                num_classes) - TODO confirm.
            bbox_pred (Tensor): Box deltas, (num_anchors, 4).
            labels (Tensor): Integer target labels per anchor.
            label_weights (Tensor): Per-anchor weights applied to the
                element-wise cross entropy.
            bbox_targets (Tensor): Encoded box targets, (num_anchors, 4).
            bbox_weights (Tensor): Per-coordinate regression weights.
            num_total_samples (int | float): Normaliser for both losses.
            cfg: Train config; reads ``neg_pos_ratio`` and ``smoothl1_beta``.

        Returns:
            tuple[Tensor, Tensor]: Classification loss with a leading size-1
            dimension, and the smooth-L1 regression loss.
        """
        loss_cls_all = F.cross_entropy(
            cls_score, labels, reduction='none') * label_weights
        pos_inds = (labels > 0).nonzero().view(-1)
        neg_inds = (labels == 0).nonzero().view(-1)

        num_pos_samples = pos_inds.size(0)
        # topk() requires an int k, so cast in case neg_pos_ratio is a float.
        num_neg_samples = min(
            int(cfg.neg_pos_ratio * num_pos_samples), neg_inds.size(0))
        topk_loss_cls_neg, _ = loss_cls_all[neg_inds].topk(num_neg_samples)
        loss_cls_pos = loss_cls_all[pos_inds].sum()
        loss_cls_neg = topk_loss_cls_neg.sum()
        loss_cls = (loss_cls_pos + loss_cls_neg) / num_total_samples

        loss_bbox = smooth_l1_loss(
            bbox_pred,
            bbox_targets,
            bbox_weights,
            beta=cfg.smoothl1_beta,
            avg_factor=num_total_samples)
        return loss_cls[None], loss_bbox

    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        """Compute SSD losses for a batch.

        Anchors are assigned with ``anchor_target`` (no sampling; hard
        negative mining happens in ``loss_single``), the per-level outputs
        are flattened into one anchor axis per image, and ``loss_single`` is
        applied image by image via ``multi_apply``.

        Args:
            cls_scores (list[Tensor]): Per-level outputs of ``forward``;
                assumed (N, A * num_classes, H, W) - TODO confirm.
            bbox_preds (list[Tensor]): Per-level outputs of ``forward``;
                assumed (N, A * 4, H, W) - TODO confirm.
            gt_bboxes (list[Tensor]): Ground-truth boxes per image.
            gt_labels (list[Tensor]): Ground-truth labels per image.
            img_metas (list[dict]): Image meta information; its length is
                used as the number of images.
            cfg: Train config forwarded to ``anchor_target`` and
                ``loss_single``.
            gt_bboxes_ignore (list[Tensor] | None): Boxes to ignore,
                forwarded to ``anchor_target``.

        Returns:
            dict | None: ``dict(loss_cls=..., loss_bbox=...)`` holding the
            per-image results collected by ``multi_apply``, or ``None`` when
            ``anchor_target`` returns ``None``.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == len(self.anchor_generators)

        device = cls_scores[0].device

        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, img_metas, device=device)
        cls_reg_targets = anchor_target(
            anchor_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            self.target_means,
            self.target_stds,
            cfg,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=1,
            sampling=False,
            unmap_outputs=False)
        if cls_reg_targets is None:
            return None
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg) = cls_reg_targets

        # Flatten every level into a single anchor axis per image, in the
        # same (H, W, anchor) order for predictions and targets.
        num_images = len(img_metas)
        all_cls_scores = torch.cat([
            s.permute(0, 2, 3, 1).reshape(
                num_images, -1, self.cls_out_channels) for s in cls_scores
        ], 1)
        all_labels = torch.cat(labels_list, -1).view(num_images, -1)
        all_label_weights = torch.cat(label_weights_list,
                                      -1).view(num_images, -1)
        all_bbox_preds = torch.cat([
            b.permute(0, 2, 3, 1).reshape(num_images, -1, 4)
            for b in bbox_preds
        ], -2)
        all_bbox_targets = torch.cat(bbox_targets_list,
                                     -2).view(num_images, -1, 4)
        all_bbox_weights = torch.cat(bbox_weights_list,
                                     -2).view(num_images, -1, 4)

        # check NaN and Inf
        assert torch.isfinite(all_cls_scores).all().item(), \
            'classification scores become infinite or NaN!'
        assert torch.isfinite(all_bbox_preds).all().item(), \
            'bbox predictions become infinite or NaN!'

        losses_cls, losses_bbox = multi_apply(
            self.loss_single,
            all_cls_scores,
            all_bbox_preds,
            all_labels,
            all_label_weights,
            all_bbox_targets,
            all_bbox_weights,
            num_total_samples=num_total_pos,
            cfg=cfg)
        return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)