fcos2d.py 15.3 KB
Newer Older
lishj6's avatar
init  
lishj6 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
# Copyright 2021 Toyota Research Institute.  All rights reserved.
# Adapted from AdelaiDet:
#   https://github.com/aim-uofa/AdelaiDet
import torch
from fvcore.nn import sigmoid_focal_loss
from torch import nn
from torch.nn import functional as F

from detectron2.layers import Conv2d, batched_nms, cat, get_norm
from detectron2.structures import Boxes, Instances
from detectron2.utils.comm import get_world_size
from mmcv.runner import force_fp32

from projects.mmdet3d_plugin.dd3d.layers.iou_loss import IOULoss
from projects.mmdet3d_plugin.dd3d.layers.normalization import ModuleListDial, Scale
from projects.mmdet3d_plugin.dd3d.utils.comm import reduce_sum

# Large integer constant. It is not referenced by any definition in the portion of the file reviewed.
# NOTE(review): presumably an "infinity" sentinel inherited from the upstream code -- confirm before removing.
INF = 100000000


def compute_ctrness_targets(reg_targets):
    """Compute centerness targets from per-location box regression targets.

    Columns 0 and 2 of ``reg_targets`` are treated as one pair (left/right) and
    columns 1 and 3 as the other (top/bottom). For every row the result is
    ``sqrt(min(l, r) / max(l, r) * min(t, b) / max(t, b))``.

    Args:
        reg_targets: 2D tensor whose second dimension is indexed at 0..3.

    Returns:
        A 1D tensor with one value per row of ``reg_targets``; an empty tensor
        when ``reg_targets`` has no rows.
    """
    num_targets = len(reg_targets)
    if num_targets == 0:
        # Nothing to compute; return an empty tensor on the same device/dtype.
        return reg_targets.new_zeros(num_targets)

    horizontal = reg_targets[:, [0, 2]]
    vertical = reg_targets[:, [1, 3]]

    horizontal_ratio = horizontal.min(dim=-1).values / horizontal.max(dim=-1).values
    vertical_ratio = vertical.min(dim=-1).values / vertical.max(dim=-1).values

    return torch.sqrt(horizontal_ratio * vertical_ratio)


class FCOS2DHead(nn.Module):
    """FCOS-style 2D detection head applied to every input feature level.

    Two convolutional towers (``cls_tower`` and ``box2d_tower``) are run on each
    feature map. Class logits are predicted from the classification tower; the
    4-channel box regression and the 1-channel centerness are both predicted
    from the box tower.

    Args:
        num_classes: number of output channels of the classification predictor.
        input_shape: sequence of per-level descriptors exposing ``.stride`` and
            ``.channels``; every level must have the same channel count.
        num_cls_convs: number of conv layers in the classification tower.
        num_box_convs: number of conv layers in the box tower.
        norm: normalization name. With ``version="v1"`` only ``"BN"`` adds a
            norm layer, ``"GN"``/``"NaiveGN"``/``"SyncBN"`` raise
            ``NotImplementedError`` and any other value adds no normalization.
            With ``version="v2"`` the value is forwarded to ``get_norm``.
        use_deformable: must be falsy; a truthy value raises ``ValueError``.
        use_scale: when true, a per-level ``Scale`` module is applied to the raw
            box regression output.
        box2d_scale_init_factor: multiplied by each level's stride to obtain the
            initial value of that level's ``Scale``.
        version: ``"v1"`` or ``"v2"``; selects how the towers are built and the
            attribute name under which the scale modules are stored.

    Raises:
        ValueError: if ``use_deformable`` is truthy or ``version`` is unknown.
    """

    def __init__(self,
                 num_classes,
                 input_shape,
                 num_cls_convs=4,
                 num_box_convs=4,
                 norm='BN',
                 use_deformable=False,
                 use_scale=True,
                 box2d_scale_init_factor=1.0,
                 version='v2'):
        super().__init__()

        self.num_classes = num_classes
        self.in_strides = [level.stride for level in input_shape]
        self.num_levels = len(input_shape)
        self.use_scale = use_scale
        self.box2d_scale_init_factor = box2d_scale_init_factor
        self._version = version

        level_channels = [level.channels for level in input_shape]
        assert len(set(level_channels)) == 1, "Each level must have the same channel!"
        in_channels = level_channels[0]

        if use_deformable:
            raise ValueError("Not supported yet.")

        # Towers are registered as `cls_tower` and `box2d_tower`, in this order.
        tower_depths = (('cls', num_cls_convs), ('box2d', num_box_convs))
        for head_name, num_convs in tower_depths:
            tower_layers = self._make_tower_layers(num_convs, in_channels, norm)
            self.add_module(f'{head_name}_tower', nn.Sequential(*tower_layers))

        # All three predictors share the same 3x3 / stride-1 / pad-1 geometry.
        predictor_kwargs = dict(kernel_size=3, stride=1, padding=1)
        self.cls_logits = nn.Conv2d(in_channels, self.num_classes, **predictor_kwargs)
        self.box2d_reg = nn.Conv2d(in_channels, 4, **predictor_kwargs)
        self.centerness = nn.Conv2d(in_channels, 1, **predictor_kwargs)

        if self.use_scale:
            per_level_scales = nn.ModuleList([
                Scale(init_value=stride * self.box2d_scale_init_factor) for stride in self.in_strides
            ])
            # The attribute name depends on the version, which also changes the
            # parameter names in the state dict.
            if self._version == "v1":
                self.scales_reg = per_level_scales
            else:
                self.scales_box2d_reg = per_level_scales

        self.init_weights()

    def _make_tower_layers(self, num_convs, in_channels, norm):
        """Return the list of layers of one tower, built according to ``self._version``."""
        if self._version == "v1":
            layers = []
            for _ in range(num_convs):
                layers.append(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=True))
                if norm in ("GN", "NaiveGN", "SyncBN"):
                    raise NotImplementedError()
                if norm == "BN":
                    # One BatchNorm2d per feature level.
                    layers.append(ModuleListDial([nn.BatchNorm2d(in_channels) for _ in range(self.num_levels)]))
                layers.append(nn.ReLU())
            return layers

        if self._version == "v2":
            per_level_norm_names = ("BN", "FrozenBN", "SyncBN", "GN")
            layers = []
            for _ in range(num_convs):
                if norm in per_level_norm_names:
                    # NOTE: need to add norm here!
                    # Each FPN level has its own batchnorm layer.
                    # NOTE: do not use dd3d train.py!
                    # "BN" is converted to "SyncBN" in distributed training (see train.py)
                    norm_layer = ModuleListDial([get_norm(norm, in_channels) for _ in range(self.num_levels)])
                else:
                    norm_layer = get_norm(norm, in_channels)
                # The conv only carries a bias when no norm layer follows it.
                layers.append(
                    Conv2d(
                        in_channels,
                        in_channels,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                        bias=norm_layer is None,
                        norm=norm_layer,
                        activation=F.relu
                    )
                )
            return layers

        raise ValueError(f"Invalid FCOS2D version: {self._version}")

    def init_weights(self):
        """Initialise the conv layers of both towers and of the three predictors.

        Tower convs use Kaiming-normal (fan_out, relu); predictor convs use
        Kaiming-uniform with ``a=1``. Every existing bias is set to zero.
        """
        for tower in (self.cls_tower, self.box2d_tower):
            for module in tower.modules():
                if not isinstance(module, nn.Conv2d):
                    continue
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        for predictor in (self.cls_logits, self.box2d_reg, self.centerness):
            for module in predictor.modules():
                if not isinstance(module, nn.Conv2d):
                    continue
                nn.init.kaiming_uniform_(module.weight, a=1)
                if module.bias is not None:  # depth head may not have bias.
                    nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Run the head on every feature level.

        Args:
            x: iterable of feature maps, one per level.

        Returns:
            Tuple ``(logits, box2d_reg, centerness, extra_output)``. The first
            three are lists with one tensor per level; ``extra_output`` is a dict
            whose ``"cls_tower_out"`` entry lists the classification tower
            output of each level.
        """
        logits, box2d_reg, centerness = [], [], []
        cls_tower_outputs = []

        for level, feature in enumerate(x):
            cls_features = self.cls_tower(feature)
            box_features = self.box2d_tower(feature)

            # 2D box
            logits.append(self.cls_logits(cls_features))
            centerness.append(self.centerness(box_features))

            raw_box_reg = self.box2d_reg(box_features)
            if self.use_scale:
                # TODO: to optimize the runtime, apply this scaling in inference (and loss compute) only on FG pixels?
                level_scales = self.scales_reg if self._version == "v1" else self.scales_box2d_reg
                raw_box_reg = level_scales[level](raw_box_reg)
            # Note that we use relu, as in the improved FCOS, instead of exp.
            box2d_reg.append(F.relu(raw_box_reg))

            cls_tower_outputs.append(cls_features)

        return logits, box2d_reg, centerness, {"cls_tower_out": cls_tower_outputs}


class FCOS2DLoss(nn.Module):
    """Classification, 2D box regression and centerness losses for the FCOS2D head.

    Args:
        num_classes: number of channels of the classification logits.
        focal_loss_alpha: ``alpha`` passed to ``sigmoid_focal_loss``.
        focal_loss_gamma: ``gamma`` passed to ``sigmoid_focal_loss``.
        loc_loss_type: forwarded to ``IOULoss`` to build the box regression loss.
    """

    def __init__(self,
                 num_classes,
                 focal_loss_alpha=0.25,
                 focal_loss_gamma=2.0,
                 loc_loss_type='giou',
                 ):
        super().__init__()
        self.num_classes = num_classes
        self.focal_loss_alpha = focal_loss_alpha
        self.focal_loss_gamma = focal_loss_gamma
        self.box2d_reg_loss_fn = IOULoss(loc_loss_type)

    @force_fp32(apply_to=('logits', 'box2d_reg', 'centerness'))
    def forward(self, logits, box2d_reg, centerness, targets):
        """Compute the three FCOS2D losses.

        Args:
            logits: list of per-level classification maps.
            box2d_reg: list of per-level 4-channel box regression maps.
            centerness: list of per-level centerness maps.
            targets: dict with keys ``'labels'``, ``'box2d_reg_targets'`` and
                ``'pos_inds'`` (indices of foreground locations in the
                flattened predictions).

        Returns:
            Tuple ``(loss_dict, extra_info)``. ``loss_dict`` has the keys
            ``loss_cls``, ``loss_box2d_reg`` and ``loss_centerness``. When there
            is no local foreground location, the last two are zero-valued and
            ``extra_info`` is an empty dict; otherwise ``extra_info`` holds
            ``loss_denom`` and ``centerness_targets``.

        Raises:
            ValueError: if ``labels`` and ``box2d_reg_targets`` differ in length.
        """
        labels = targets['labels']
        box2d_reg_targets = targets['box2d_reg_targets']
        pos_inds = targets["pos_inds"]

        if len(labels) != box2d_reg_targets.shape[0]:
            raise ValueError(
                f"The size of 'labels' and 'box2d_reg_targets' does not match: a={len(labels)}, b={box2d_reg_targets.shape[0]}"
            )

        def flatten_levels(per_level_maps, *trailing_dims):
            # Move channels last, then merge batch and spatial dims of every level.
            return cat([fmap.permute(0, 2, 3, 1).reshape(-1, *trailing_dims) for fmap in per_level_maps])

        flat_logits = flatten_levels(logits, self.num_classes)
        flat_box2d_reg = flatten_levels(box2d_reg, 4)
        flat_centerness = flatten_levels(centerness)

        # -------------------
        # Classification loss
        # -------------------
        num_gpus = get_world_size()
        num_pos_local = pos_inds.numel()
        total_num_pos = reduce_sum(pos_inds.new_tensor([num_pos_local])).item()
        # Average number of positives per process, never below one.
        num_pos_avg = max(total_num_pos / num_gpus, 1.0)

        # One-hot classification target: only foreground rows get a 1.
        one_hot_target = torch.zeros_like(flat_logits)
        one_hot_target[pos_inds, labels[pos_inds]] = 1

        focal_sum = sigmoid_focal_loss(
            flat_logits,
            one_hot_target,
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        )
        loss_cls = focal_sum / num_pos_avg

        # NOTE: The rest of losses only consider foreground pixels.
        fg_box2d_reg_pred = flat_box2d_reg[pos_inds]
        fg_box2d_reg_targets = box2d_reg_targets[pos_inds]
        fg_centerness_pred = flat_centerness[pos_inds]

        # Centerness targets are derived from the 2D regression targets of foreground pixels.
        centerness_targets = compute_ctrness_targets(fg_box2d_reg_targets)

        # Denominator for all foreground losses.
        reduced_ctrness_sum = reduce_sum(centerness_targets.sum()).item()
        loss_denom = max(reduced_ctrness_sum / num_gpus, 1e-6)

        # NOTE: the early return must stay after both reduce_sum calls.
        if num_pos_local == 0:
            # Multiplying by zero keeps the losses attached to the predictions.
            empty_losses = {
                "loss_cls": loss_cls,
                "loss_box2d_reg": fg_box2d_reg_pred.sum() * 0.,
                "loss_centerness": fg_centerness_pred.sum() * 0.,
            }
            return empty_losses, {}

        # ----------------------
        # 2D box regression loss
        # ----------------------
        # The centerness targets are passed as the third argument of the IoU loss
        # (presumably per-location weights -- confirm against IOULoss).
        box_loss_sum = self.box2d_reg_loss_fn(fg_box2d_reg_pred, fg_box2d_reg_targets, centerness_targets)
        loss_box2d_reg = box_loss_sum / loss_denom

        # ---------------
        # Centerness loss
        # ---------------
        centerness_bce_sum = F.binary_cross_entropy_with_logits(
            fg_centerness_pred, centerness_targets, reduction="sum"
        )
        loss_centerness = centerness_bce_sum / num_pos_avg

        loss_dict = {
            "loss_cls": loss_cls,
            "loss_box2d_reg": loss_box2d_reg,
            "loss_centerness": loss_centerness,
        }
        extra_info = {
            "loss_denom": loss_denom,
            "centerness_targets": centerness_targets,
        }
        return loss_dict, extra_info


class FCOS2DInference():
    """Turn FCOS2D head outputs into per-image detections.

    Thresholds and top-k limits are read from ``cfg.DD3D.FCOS2D.INFERENCE``;
    the number of classes is read from ``cfg.DD3D.NUM_CLASSES``.
    """

    def __init__(self, cfg):
        inference_cfg = cfg.DD3D.FCOS2D.INFERENCE
        self.thresh_with_ctr = inference_cfg.THRESH_WITH_CTR
        self.pre_nms_thresh = inference_cfg.PRE_NMS_THRESH
        self.pre_nms_topk = inference_cfg.PRE_NMS_TOPK
        self.post_nms_topk = inference_cfg.POST_NMS_TOPK
        self.nms_thresh = inference_cfg.NMS_THRESH
        self.num_classes = cfg.DD3D.NUM_CLASSES

    def __call__(self, logits, box2d_reg, centerness, locations, image_sizes):
        """Decode candidates for every feature level.

        Args:
            logits, box2d_reg, centerness, locations: per-level sequences that
                are iterated in lockstep.
            image_sizes: per-image sizes, indexed by batch position.

        Returns:
            Tuple ``(pred_instances, extra_info)``, both lists with one entry
            per level. Each entry of ``pred_instances`` is a list of
            ``Instances`` (one per image) tagged with ``fpn_levels``.
        """
        pred_instances = []  # List[List[Instances]], shape = (L, B)
        extra_info = []

        per_level_inputs = zip(logits, box2d_reg, centerness, locations)
        for lvl, (logits_lvl, box2d_reg_lvl, centerness_lvl, locations_lvl) in enumerate(per_level_inputs):
            # List of Instances; one for each image.
            instances_per_lvl, extra_info_per_lvl = self.forward_for_single_feature_map(
                logits_lvl, box2d_reg_lvl, centerness_lvl, locations_lvl, image_sizes
            )

            # Record which level each candidate came from.
            for instances_per_im in instances_per_lvl:
                num_instances = len(instances_per_im)
                instances_per_im.fpn_levels = locations_lvl.new_full((num_instances, ), lvl, dtype=torch.long)

            pred_instances.append(instances_per_lvl)
            extra_info.append(extra_info_per_lvl)

        return pred_instances, extra_info

    def forward_for_single_feature_map(self, logits, box2d_reg, centerness, locations, image_sizes):
        """Select and decode candidates of one feature level.

        Args:
            logits: 4D tensor; dim 0 is the batch and dim 1 the classes.
            box2d_reg: 4D tensor reshaped here to ``(N, -1, 4)``.
            centerness: 4D tensor reshaped here to ``(N, -1)``.
            locations: tensor indexed by flattened location; columns 0 and 1 are
                used as the x and y reference of each box.
            image_sizes: per-image sizes passed to ``Instances``.

        Returns:
            Tuple ``(results, extra_info)``: a list with one ``Instances`` per
            image, and a dict with the per-image lists ``fg_inds_per_im``,
            ``class_inds_per_im`` (both recorded before the top-k cut) and
            ``topk_indices`` (``None`` where no cut was needed).
        """
        num_images, num_classes, _, _ = logits.shape

        # put in the same format as locations
        scores = logits.permute(0, 2, 3, 1).reshape(num_images, -1, num_classes).sigmoid()
        box2d_reg = box2d_reg.permute(0, 2, 3, 1).reshape(num_images, -1, 4)
        centerness = centerness.permute(0, 2, 3, 1).reshape(num_images, -1).sigmoid()

        # When self.thresh_with_ctr is true the centerness is folded into the
        # scores before thresholding; otherwise it is folded in afterwards.
        if self.thresh_with_ctr:
            scores = scores * centerness[:, :, None]

        candidate_mask = scores > self.pre_nms_thresh

        # Per-image number of candidates, capped at the configured top-k.
        num_candidates = candidate_mask.reshape(num_images, -1).sum(1)
        pre_nms_topk = num_candidates.clamp(max=self.pre_nms_topk)

        if not self.thresh_with_ctr:
            scores = scores * centerness[:, :, None]

        results = []
        all_fg_inds_per_im = []
        all_class_inds_per_im = []
        all_topk_indices = []
        for im_idx in range(num_images):
            mask_im = candidate_mask[im_idx]
            scores_im = scores[im_idx][mask_im]

            # Each row of the nonzero output is (location index, class index).
            candidate_inds = mask_im.nonzero(as_tuple=False)
            fg_inds = candidate_inds[:, 0]
            class_inds = candidate_inds[:, 1]

            # Cache info here.
            all_fg_inds_per_im.append(fg_inds)
            all_class_inds_per_im.append(class_inds)

            box2d_reg_im = box2d_reg[im_idx][fg_inds]
            locations_im = locations[fg_inds]

            topk_im = pre_nms_topk[im_idx]
            topk_indices = None
            if mask_im.sum().item() > topk_im.item():
                scores_im, topk_indices = scores_im.topk(topk_im, sorted=False)
                class_inds = class_inds[topk_indices]
                box2d_reg_im = box2d_reg_im[topk_indices]
                locations_im = locations_im[topk_indices]
            all_topk_indices.append(topk_indices)

            # Regression values are distances from the location to the box sides.
            x_ref = locations_im[:, 0]
            y_ref = locations_im[:, 1]
            detections = torch.stack(
                [
                    x_ref - box2d_reg_im[:, 0],
                    y_ref - box2d_reg_im[:, 1],
                    x_ref + box2d_reg_im[:, 2],
                    y_ref + box2d_reg_im[:, 3],
                ],
                dim=1,
            )

            instances = Instances(image_sizes[im_idx])
            instances.pred_boxes = Boxes(detections)
            instances.scores = torch.sqrt(scores_im)
            instances.pred_classes = class_inds
            instances.locations = locations_im
            results.append(instances)

        extra_info = {
            "fg_inds_per_im": all_fg_inds_per_im,
            "class_inds_per_im": all_class_inds_per_im,
            "topk_indices": all_topk_indices
        }
        return results, extra_info

    def nms_and_top_k(self, instances_per_im, score_key_for_nms="scores"):
        """Apply class-wise NMS and keep at most ``post_nms_topk`` detections per image.

        Args:
            instances_per_im: iterable of per-image ``Instances``.
            score_key_for_nms: name of the field used as the NMS score. The
                post-NMS cut always ranks by the ``scores`` field.

        Returns:
            List with the filtered ``Instances`` of every image.
        """
        results = []
        for instances in instances_per_im:
            if self.nms_thresh > 0:
                # Multiclass NMS.
                nms_scores = instances.get(score_key_for_nms)
                keep = batched_nms(
                    instances.pred_boxes.tensor, nms_scores, instances.pred_classes, self.nms_thresh
                )
                instances = instances[keep]

            # Limit to max_per_image detections **over all classes**
            num_detections = len(instances)
            if 0 < self.post_nms_topk < num_detections:
                scores = instances.scores
                # Score of the post_nms_topk-th best detection; ties at the threshold are kept.
                kth_smallest = num_detections - self.post_nms_topk + 1
                image_thresh = torch.kthvalue(scores, kth_smallest).values
                keep = torch.nonzero(scores >= image_thresh.item()).squeeze(1)
                instances = instances[keep]

            results.append(instances)
        return results