free_anchor3d_head.py 11.8 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
2
3
from typing import Dict, List

zhangwenwei's avatar
zhangwenwei committed
4
import torch
5
from torch import Tensor
zhangwenwei's avatar
zhangwenwei committed
6
from torch.nn import functional as F
zhangwenwei's avatar
zhangwenwei committed
7

8
from mmdet3d.registry import MODELS
zhangshilong's avatar
zhangshilong committed
9
10
from mmdet3d.structures import bbox_overlaps_nearest_3d
from mmdet3d.utils import InstanceList, OptInstanceList
zhangwenwei's avatar
zhangwenwei committed
11
12
13
14
from .anchor3d_head import Anchor3DHead
from .train_mixins import get_direction_target


15
@MODELS.register_module()
class FreeAnchor3DHead(Anchor3DHead):
    r"""`FreeAnchor <https://arxiv.org/abs/1909.02466>`_ head for 3D detection.

    Note:
        This implementation is directly modified from the `mmdet implementation
        <https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/dense_heads/free_anchor_retina_head.py>`_.
        We find it also works on 3D detection with minor modification, i.e.,
        different hyper-parameters and an additional direction classifier.

    Args:
        pre_anchor_topk (int): Number of boxes that are taken into each bag.
        bbox_thr (float): The threshold of the saturated linear function. It is
            usually the same with the IoU threshold used in NMS.
        gamma (float): Gamma parameter in focal loss.
        alpha (float): Alpha parameter in focal loss.
        init_cfg (dict, optional): Initialization config of the head.
            Defaults to None.
        kwargs (dict): Other arguments are the same as those in :class:`Anchor3DHead`.
    """  # noqa: E501

    def __init__(self,
                 pre_anchor_topk: int = 50,
                 bbox_thr: float = 0.6,
                 gamma: float = 2.0,
                 alpha: float = 0.5,
                 init_cfg: dict = None,
                 **kwargs) -> None:
        super().__init__(init_cfg=init_cfg, **kwargs)
        # Bag construction and focal-loss hyper-parameters of FreeAnchor.
        self.pre_anchor_topk = pre_anchor_topk
        self.bbox_thr = bbox_thr
        self.gamma = gamma
        self.alpha = alpha

47
48
49
50
51
52
53
54
    def loss_by_feat(
            self,
            cls_scores: List[Tensor],
            bbox_preds: List[Tensor],
            dir_cls_preds: List[Tensor],
            batch_gt_instances_3d: InstanceList,
            batch_input_metas: List[dict],
            batch_gt_instances_ignore: OptInstanceList = None) -> Dict:
        """Calculate loss of FreeAnchor head.

        Args:
            cls_scores (list[torch.Tensor]): Classification scores of
                different samples.
            bbox_preds (list[torch.Tensor]): Box predictions of
                different samples
            dir_cls_preds (list[torch.Tensor]): Direction predictions of
                different samples
            batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
                gt_instances. It usually includes ``bboxes_3d`` and
                ``labels_3d`` attributes.
            batch_input_metas (list[dict]): Contain pcd and img's meta info.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.

        Returns:
            dict[str, torch.Tensor]: Loss items.

                - positive_bag_loss (torch.Tensor): Loss of positive samples.
                - negative_bag_loss (torch.Tensor): Loss of negative samples.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == self.prior_generator.num_levels

        anchor_list = self.get_anchors(featmap_sizes, batch_input_metas)
        # One flat anchor tensor per sample (all levels concatenated).
        mlvl_anchors = [torch.cat(anchor) for anchor in anchor_list]

        # concatenate each level
        cls_scores = [
            cls_score.permute(0, 2, 3, 1).reshape(
                cls_score.size(0), -1, self.num_classes)
            for cls_score in cls_scores
        ]
        bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(
                bbox_pred.size(0), -1, self.box_code_size)
            for bbox_pred in bbox_preds
        ]
        dir_cls_preds = [
            dir_cls_pred.permute(0, 2, 3,
                                 1).reshape(dir_cls_pred.size(0), -1, 2)
            for dir_cls_pred in dir_cls_preds
        ]

        cls_scores = torch.cat(cls_scores, dim=1)
        bbox_preds = torch.cat(bbox_preds, dim=1)
        dir_cls_preds = torch.cat(dir_cls_preds, dim=1)

        # Batched classification probabilities, shape [batch, j, c].
        cls_probs = torch.sigmoid(cls_scores)
        box_prob = []
        num_pos = 0
        positive_losses = []
        for _, (anchors, gt_instance_3d, cls_prob, bbox_pred,
                dir_cls_pred) in enumerate(
                    zip(mlvl_anchors, batch_gt_instances_3d, cls_probs,
                        bbox_preds, dir_cls_preds)):
            gt_bboxes = gt_instance_3d.bboxes_3d.tensor.to(anchors.device)
            gt_labels = gt_instance_3d.labels_3d.to(anchors.device)
            with torch.no_grad():
                # box_localization: a_{j}^{loc}, shape: [j, 4]
                pred_boxes = self.bbox_coder.decode(anchors, bbox_pred)

                # object_box_iou: IoU_{ij}^{loc}, shape: [i, j]
                object_box_iou = bbox_overlaps_nearest_3d(
                    gt_bboxes, pred_boxes)

                # object_box_prob: P{a_{j} -> b_{i}}, shape: [i, j]
                # Saturated linear transfer of IoU into [0, 1]; the 1e-6
                # guard keeps the denominator strictly positive.
                t1 = self.bbox_thr
                t2 = object_box_iou.max(
                    dim=1, keepdim=True).values.clamp(min=t1 + 1e-6)
                object_box_prob = ((object_box_iou - t1) / (t2 - t1)).clamp(
                    min=0, max=1)

                # object_cls_box_prob: P{a_{j} -> b_{i}}, shape: [i, c, j]
                num_obj = gt_labels.size(0)
                indices = torch.stack(
                    [torch.arange(num_obj).type_as(gt_labels), gt_labels],
                    dim=0)

                object_cls_box_prob = torch.sparse_coo_tensor(
                    indices, object_box_prob)

                # image_box_iou: P{a_{j} \in A_{+}}, shape: [c, j]
                """
                from "start" to "end" implement:
                image_box_iou = torch.sparse.max(object_cls_box_prob,
                                                 dim=0).t()

                """
                # start
                box_cls_prob = torch.sparse.sum(
                    object_cls_box_prob, dim=0).to_dense()

                indices = torch.nonzero(box_cls_prob, as_tuple=False).t_()
                if indices.numel() == 0:
                    # No anchor has a positive matching probability.
                    image_box_prob = torch.zeros(
                        anchors.size(0),
                        self.num_classes).type_as(object_box_prob)
                else:
                    nonzero_box_prob = torch.where(
                        (gt_labels.unsqueeze(dim=-1) == indices[0]),
                        object_box_prob[:, indices[1]],
                        torch.tensor(
                            [0]).type_as(object_box_prob)).max(dim=0).values

                    # upmap to shape [j, c]
                    image_box_prob = torch.sparse_coo_tensor(
                        indices.flip([0]),
                        nonzero_box_prob,
                        size=(anchors.size(0), self.num_classes)).to_dense()
                # end

                box_prob.append(image_box_prob)

            # construct bags for objects
            match_quality_matrix = bbox_overlaps_nearest_3d(gt_bboxes, anchors)
            _, matched = torch.topk(
                match_quality_matrix,
                self.pre_anchor_topk,
                dim=1,
                sorted=False)
            del match_quality_matrix

            # matched_cls_prob: P_{ij}^{cls}
            matched_cls_prob = torch.gather(
                cls_prob[matched], 2,
                gt_labels.view(-1, 1, 1).repeat(1, self.pre_anchor_topk,
                                                1)).squeeze(2)

            # matched_box_prob: P_{ij}^{loc}
            matched_anchors = anchors[matched]
            matched_object_targets = self.bbox_coder.encode(
                matched_anchors,
                gt_bboxes.unsqueeze(dim=1).expand_as(matched_anchors))

            # direction classification loss
            loss_dir = None
            if self.use_direction_classifier:
                # also calculate direction prob: P_{ij}^{dir}
                matched_dir_targets = get_direction_target(
                    matched_anchors,
                    matched_object_targets,
                    self.dir_offset,
                    self.dir_limit_offset,
                    one_hot=False)
                loss_dir = self.loss_dir(
                    dir_cls_pred[matched].transpose(-2, -1),
                    matched_dir_targets,
                    reduction_override='none')

            # generate bbox weights
            if self.diff_rad_by_sin:
                # Work on a clone so the sin-difference transform does not
                # mutate ``bbox_pred`` (a view into the batched predictions).
                bbox_preds_clone = bbox_pred.clone()
                bbox_preds_clone[matched], matched_object_targets = \
                    self.add_sin_difference(
                        bbox_preds_clone[matched], matched_object_targets)
            else:
                # Bugfix: ``bbox_preds_clone`` was previously only assigned
                # inside the branch above, raising a NameError below whenever
                # ``diff_rad_by_sin`` is False. No in-place write happens on
                # this path, so aliasing ``bbox_pred`` is safe.
                bbox_preds_clone = bbox_pred
            bbox_weights = matched_anchors.new_ones(matched_anchors.size())
            # Use pop is not right, check performance
            code_weight = self.train_cfg.get('code_weight', None)
            if code_weight:
                bbox_weights = bbox_weights * bbox_weights.new_tensor(
                    code_weight)
            loss_bbox = self.loss_bbox(
                bbox_preds_clone[matched],
                matched_object_targets,
                bbox_weights,
                reduction_override='none').sum(-1)

            if loss_dir is not None:
                loss_bbox += loss_dir
            matched_box_prob = torch.exp(-loss_bbox)

            # positive_losses: {-log( Mean-max(P_{ij}^{cls} * P_{ij}^{loc}) )}
            num_pos += len(gt_bboxes)
            positive_losses.append(
                self.positive_bag_loss(matched_cls_prob, matched_box_prob))

        positive_loss = torch.cat(positive_losses).sum() / max(1, num_pos)

        # box_prob: P{a_{j} \in A_{+}}
        box_prob = torch.stack(box_prob, dim=0)

        # negative_loss:
        # \sum_{j}{ FL((1 - P{a_{j} \in A_{+}}) * (1 - P_{j}^{bg})) } / n||B||
        # Bugfix: use the batched ``cls_probs`` ([batch, j, c]) rather than
        # the loop-local ``cls_prob``, which only held the LAST sample's
        # probabilities and was silently broadcast over the whole batch.
        negative_loss = self.negative_bag_loss(cls_probs, box_prob).sum() / max(
            1, num_pos * self.pre_anchor_topk)

        losses = {
            'positive_bag_loss': positive_loss,
            'negative_bag_loss': negative_loss
        }
        return losses

252
253
    def positive_bag_loss(self, matched_cls_prob: Tensor,
                          matched_box_prob: Tensor) -> Tensor:
        """Compute the positive bag loss of FreeAnchor.

        The per-anchor joint probability ``P_{ij}^{cls} * P_{ij}^{loc}`` is
        aggregated over each bag with the Mean-max function, and the result
        is penalized with a binary cross entropy against the target 1.

        Args:
            matched_cls_prob (torch.Tensor): Classification probability
                of matched positive samples.
            matched_box_prob (torch.Tensor): Bounding box probability
                of matched positive samples.

        Returns:
            torch.Tensor: Loss of positive samples.
        """
        # bag_prob = Mean-max(matched_prob)
        joint_prob = matched_cls_prob * matched_box_prob
        # Mean-max weighting: anchors with probability closer to 1 receive
        # larger weights; the 1e-12 floor avoids division by zero.
        weights = (1 - joint_prob).clamp(min=1e-12).reciprocal()
        weights = weights / weights.sum(dim=1, keepdim=True)
        bag_prob = (weights * joint_prob).sum(dim=1)
        # Keep the probability strictly inside [0, 1] before BCE.
        bag_prob = bag_prob.clamp(0, 1)
        # positive_bag_loss = -self.alpha * log(bag_prob)
        return self.alpha * F.binary_cross_entropy(
            bag_prob, torch.ones_like(bag_prob), reduction='none')

275
    def negative_bag_loss(self, cls_prob: Tensor, box_prob: Tensor) -> Tensor:
        """Compute the negative bag (background) loss of FreeAnchor.

        Anchors unlikely to match any object (``1 - box_prob``) are pushed
        towards background with a focal-style binary cross entropy.

        Args:
            cls_prob (torch.Tensor): Classification probability
                of negative samples.
            box_prob (torch.Tensor): Bounding box probability
                of negative samples.

        Returns:
            torch.Tensor: Loss of negative samples.
        """
        # Probability of being foreground while unmatched; clamp keeps the
        # value inside [0, 1] before BCE.
        prob = (cls_prob * (1 - box_prob)).clamp(0, 1)
        # Focal-style down-weighting of easy negatives with exponent gamma.
        focal_weight = prob.pow(self.gamma)
        bce = F.binary_cross_entropy(
            prob, torch.zeros_like(prob), reduction='none')
        return (1 - self.alpha) * (focal_weight * bce)