"official/legacy/bert/serving.py" did not exist on "7ba713c985f9d28310489670dd086edbe8e103c4"
free_anchor3d_head.py 11.9 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List

import torch
from mmengine.device import get_device
from torch import Tensor
from torch.nn import functional as F

from mmdet3d.registry import MODELS
from mmdet3d.structures import bbox_overlaps_nearest_3d
from mmdet3d.utils import InstanceList, OptInstanceList
from .anchor3d_head import Anchor3DHead
from .train_mixins import get_direction_target


@MODELS.register_module()
class FreeAnchor3DHead(Anchor3DHead):
    r"""`FreeAnchor <https://arxiv.org/abs/1909.02466>`_ head for 3D detection.

    Note:
        This implementation is directly modified from the `mmdet implementation
        <https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/dense_heads/free_anchor_retina_head.py>`_.
        We find it also works on 3D detection with minor modifications, i.e.,
        different hyper-parameters and an additional direction classifier.

    Args:
        pre_anchor_topk (int): Number of boxes taken into each bag.
        bbox_thr (float): The threshold of the saturated linear function.
            It is usually the same as the IoU threshold used in NMS.
        gamma (float): Gamma parameter in focal loss.
        alpha (float): Alpha parameter in focal loss.
        kwargs (dict): Other arguments are the same as those in :class:`Anchor3DHead`.
    """  # noqa: E501

    def __init__(self,
                 pre_anchor_topk: int = 50,
                 bbox_thr: float = 0.6,
                 gamma: float = 2.0,
                 alpha: float = 0.5,
                 init_cfg: dict = None,
                 **kwargs) -> None:
        super().__init__(init_cfg=init_cfg, **kwargs)
        self.pre_anchor_topk = pre_anchor_topk
        self.bbox_thr = bbox_thr
        self.gamma = gamma
        self.alpha = alpha

    def loss_by_feat(
            self,
            cls_scores: List[Tensor],
            bbox_preds: List[Tensor],
            dir_cls_preds: List[Tensor],
            batch_gt_instances_3d: InstanceList,
            batch_input_metas: List[dict],
            batch_gt_instances_ignore: OptInstanceList = None) -> Dict:
        """Calculate loss of FreeAnchor head.

        Args:
            cls_scores (list[torch.Tensor]): Classification scores of
                different samples.
            bbox_preds (list[torch.Tensor]): Box predictions of
                different samples.
            dir_cls_preds (list[torch.Tensor]): Direction predictions of
                different samples.
            batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
                gt_instances. It usually includes ``bboxes_3d`` and
                ``labels_3d`` attributes.
            batch_input_metas (list[dict]): Contains the point cloud and
                image meta info of each sample.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.

        Returns:
            dict[str, torch.Tensor]: Loss items.

                - positive_bag_loss (torch.Tensor): Loss of positive samples.
                - negative_bag_loss (torch.Tensor): Loss of negative samples.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == self.prior_generator.num_levels

        device = get_device()
        anchor_list = self.get_anchors(featmap_sizes, batch_input_metas,
                                       device)
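        # concatenate anchors from all feature levels for each sample, so
        # each element of mlvl_anchors covers every anchor of one sample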
        mlvl_anchors = [torch.cat(anchor) for anchor in anchor_list]

        # concatenate each level
        cls_scores = [
            cls_score.permute(0, 2, 3, 1).reshape(
                cls_score.size(0), -1, self.num_classes)
            for cls_score in cls_scores
        ]
        bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(
                bbox_pred.size(0), -1, self.box_code_size)
            for bbox_pred in bbox_preds
        ]
        dir_cls_preds = [
            dir_cls_pred.permute(0, 2, 3,
                                 1).reshape(dir_cls_pred.size(0), -1, 2)
            for dir_cls_pred in dir_cls_preds
        ]

        cls_scores = torch.cat(cls_scores, dim=1)
        bbox_preds = torch.cat(bbox_preds, dim=1)
        dir_cls_preds = torch.cat(dir_cls_preds, dim=1)
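        # shapes now: cls_scores [N, num_anchors, num_classes],
        # bbox_preds [N, num_anchors, box_code_size] and
        # dir_cls_preds [N, num_anchors, 2], where num_anchors
        # sums over all feature levels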

        cls_probs = torch.sigmoid(cls_scores)
        box_prob = []
        num_pos = 0
        positive_losses = []
        for anchors, gt_instance_3d, cls_prob, bbox_pred, dir_cls_pred in zip(
                mlvl_anchors, batch_gt_instances_3d, cls_probs, bbox_preds,
                dir_cls_preds):

            gt_bboxes = gt_instance_3d.bboxes_3d.tensor.to(anchors.device)
            gt_labels = gt_instance_3d.labels_3d.to(anchors.device)
            with torch.no_grad():
                # box_localization: a_{j}^{loc}, shape: [j, box_code_size]
                pred_boxes = self.bbox_coder.decode(anchors, bbox_pred)

                # object_box_iou: IoU_{ij}^{loc}, shape: [i, j]
                object_box_iou = bbox_overlaps_nearest_3d(
                    gt_bboxes, pred_boxes)

                # object_box_prob: P{a_{j} -> b_{i}}, shape: [i, j]
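                # (a saturated linear function: IoU <= t1 maps to 0,
                # IoU >= t2 maps to 1, linear in between; t2 is the
                # per-object maximum IoU)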
                t1 = self.bbox_thr
                t2 = object_box_iou.max(
                    dim=1, keepdim=True).values.clamp(min=t1 + 1e-6)
                object_box_prob = ((object_box_iou - t1) / (t2 - t1)).clamp(
                    min=0, max=1)

                # object_cls_box_prob: P{a_{j} -> b_{i}}, shape: [i, c, j]
                num_obj = gt_labels.size(0)
                indices = torch.stack(
                    [torch.arange(num_obj).type_as(gt_labels), gt_labels],
                    dim=0)

                object_cls_box_prob = torch.sparse_coo_tensor(
                    indices, object_box_prob)

                # image_box_prob: P{a_{j} \in A_{+}}, shape: [j, c]
                """
                from "start" to "end" implement:
                image_box_prob = torch.sparse.max(object_cls_box_prob,
                                                  dim=0).t()

                """
                # start
                box_cls_prob = torch.sparse.sum(
                    object_cls_box_prob, dim=0).to_dense()
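                # box_cls_prob [c, j] is nonzero where some object of
                # class c assigns a nonzero probability to anchor j;
                # the max below is only evaluated at those entries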

                indices = torch.nonzero(box_cls_prob, as_tuple=False).t_()
                if indices.numel() == 0:
                    image_box_prob = torch.zeros(
                        anchors.size(0),
                        self.num_classes).type_as(object_box_prob)
                else:
                    nonzero_box_prob = torch.where(
                        (gt_labels.unsqueeze(dim=-1) == indices[0]),
                        object_box_prob[:, indices[1]],
                        torch.tensor(
                            [0]).type_as(object_box_prob)).max(dim=0).values

                    # unmap to shape [j, c]
                    image_box_prob = torch.sparse_coo_tensor(
                        indices.flip([0]),
                        nonzero_box_prob,
                        size=(anchors.size(0), self.num_classes)).to_dense()
                # end

                box_prob.append(image_box_prob)

            # construct bags for objects
            match_quality_matrix = bbox_overlaps_nearest_3d(gt_bboxes, anchors)
            _, matched = torch.topk(
                match_quality_matrix,
                self.pre_anchor_topk,
                dim=1,
                sorted=False)
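            # matched: [num_gts, pre_anchor_topk] indices of the anchors
            # with the highest IoU to each gt box (the object's bag)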
            del match_quality_matrix

            # matched_cls_prob: P_{ij}^{cls}
            matched_cls_prob = torch.gather(
                cls_prob[matched], 2,
                gt_labels.view(-1, 1, 1).repeat(1, self.pre_anchor_topk,
                                                1)).squeeze(2)

            # matched_box_prob: P_{ij}^{loc}
            matched_anchors = anchors[matched]
            matched_object_targets = self.bbox_coder.encode(
                matched_anchors,
                gt_bboxes.unsqueeze(dim=1).expand_as(matched_anchors))
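            # matched_object_targets: regression targets of each gt box
            # w.r.t. its bag, shape [num_gts, pre_anchor_topk, box_code_size]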

            # direction classification loss
            loss_dir = None
            if self.use_direction_classifier:
                # also calculate direction prob: P_{ij}^{dir}
                matched_dir_targets = get_direction_target(
                    matched_anchors,
                    matched_object_targets,
                    self.dir_offset,
                    self.dir_limit_offset,
                    one_hot=False)
                loss_dir = self.loss_dir(
                    dir_cls_pred[matched].transpose(-2, -1),
                    matched_dir_targets,
                    reduction_override='none')
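                # loss_dir: per-anchor direction loss of shape
                # [num_gts, pre_anchor_topk], added to the bbox loss below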

            # encode the yaw angle as a sin difference if required, so the
            # rotation term of the regression loss is smooth in the angle
            bbox_preds_clone = bbox_pred.clone()
            if self.diff_rad_by_sin:
                bbox_preds_clone[matched], matched_object_targets = \
                    self.add_sin_difference(
                        bbox_preds_clone[matched], matched_object_targets)

            # generate bbox weights
            bbox_weights = matched_anchors.new_ones(matched_anchors.size())
            # NOTE: use `get` rather than `pop` here, otherwise `code_weight`
            # would be removed from train_cfg; check performance
            code_weight = self.train_cfg.get('code_weight', None)
            if code_weight:
                bbox_weights = bbox_weights * bbox_weights.new_tensor(
                    code_weight)
            loss_bbox = self.loss_bbox(
                bbox_preds_clone[matched],
                matched_object_targets,
                bbox_weights,
                reduction_override='none').sum(-1)

            if loss_dir is not None:
                loss_bbox += loss_dir
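            # convert the per-anchor regression loss into a localization
            # likelihood: P_{ij}^{loc} = exp(-loss)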
            matched_box_prob = torch.exp(-loss_bbox)

            # positive_losses: {-log( Mean-max(P_{ij}^{cls} * P_{ij}^{loc}) )}
            num_pos += len(gt_bboxes)
            positive_losses.append(
                self.positive_bag_loss(matched_cls_prob, matched_box_prob))

        positive_loss = torch.cat(positive_losses).sum() / max(1, num_pos)

        # box_prob: P{a_{j} \in A_{+}}
        box_prob = torch.stack(box_prob, dim=0)

        # negative_loss:
        # \sum_{j}{ FL((1 - P{a_{j} \in A_{+}}) * (1 - P_{j}^{bg})) } / n||B||
        negative_loss = self.negative_bag_loss(cls_probs, box_prob).sum() / \
            max(1, num_pos * self.pre_anchor_topk)

        losses = {
            'positive_bag_loss': positive_loss,
            'negative_bag_loss': negative_loss
        }
        return losses

    def positive_bag_loss(self, matched_cls_prob: Tensor,
                          matched_box_prob: Tensor) -> Tensor:
        """Generate positive bag loss.

        Args:
            matched_cls_prob (torch.Tensor): Classification probability
                of matched positive samples.
            matched_box_prob (torch.Tensor): Bounding box probability
                of matched positive samples.

        Returns:
            torch.Tensor: Loss of positive samples.
        """
        # bag_prob = Mean-max(matched_prob)
        matched_prob = matched_cls_prob * matched_box_prob
        weight = 1 / torch.clamp(1 - matched_prob, 1e-12, None)
        weight /= weight.sum(dim=1).unsqueeze(dim=-1)
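        # Mean-max: as one matched_prob approaches 1 its weight dominates
        # and the weighted sum tends to the max; for near-uniform
        # probabilities it degenerates to a mean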
        bag_prob = (weight * matched_prob).sum(dim=1)
        # positive_bag_loss = -self.alpha * log(bag_prob)
        bag_prob = bag_prob.clamp(0, 1)  # avoid numerical issues in BCE
        return self.alpha * F.binary_cross_entropy(
            bag_prob, torch.ones_like(bag_prob), reduction='none')

    def negative_bag_loss(self, cls_prob: Tensor, box_prob: Tensor) -> Tensor:
        """Generate negative bag loss.

        Args:
            cls_prob (torch.Tensor): Classification probability
                of negative samples.
            box_prob (torch.Tensor): Bounding box probability
                of negative samples.

        Returns:
            torch.Tensor: Loss of negative samples.
        """
        prob = cls_prob * (1 - box_prob)
        prob = prob.clamp(0, 1)  # avoid numerical issues in BCE
        negative_bag_loss = prob**self.gamma * F.binary_cross_entropy(
            prob, torch.zeros_like(prob), reduction='none')
        return (1 - self.alpha) * negative_bag_loss