# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import List, Optional, Tuple

import numpy as np
import torch
from torch import Tensor
from torch import nn as nn

from mmdet3d.core import PseudoSampler, merge_aug_bboxes_3d
from mmdet3d.core.utils import ConfigType, InstanceList, OptConfigType
from mmdet3d.core.utils.typing import OptInstanceList
from mmdet3d.registry import MODELS, TASK_UTILS
from mmdet.core import multi_apply
from .base_3d_dense_head import Base3DDenseHead
from .train_mixins import AnchorTrainMixin


@MODELS.register_module()
class Anchor3DHead(Base3DDenseHead, AnchorTrainMixin):
    """Anchor-based head for SECOND/PointPillars/MVXNet/PartA2.

    Args:
        num_classes (int): Number of classes.
        in_channels (int): Number of channels in the input feature map.
        feat_channels (int): Number of channels of the feature map.
        use_direction_classifier (bool): Whether to add a direction classifier.
        anchor_generator (dict): Config dict of anchor generator.
        assigner_per_size (bool): Whether to do assignment for each separate
            anchor size.
        assign_per_class (bool): Whether to do assignment for each class.
        diff_rad_by_sin (bool): Whether to change the difference into sin
            difference for box regression loss.
        dir_offset (float | int): The offset of BEV rotation angles.
            (TODO: may be moved into box coder)
        dir_limit_offset (float | int): The limited range of BEV
            rotation angles. (TODO: may be moved into box coder)
        bbox_coder (dict): Config dict of box coder.
        loss_cls (dict): Config of classification loss.
        loss_bbox (dict): Config of localization loss.
        loss_dir (dict): Config of direction classifier loss.
        train_cfg (dict): Train configs.
        test_cfg (dict): Test configs.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 num_classes: int,
                 in_channels: int,
                 feat_channels: int = 256,
                 use_direction_classifier: bool = True,
                 anchor_generator: ConfigType = dict(
                     type='Anchor3DRangeGenerator',
                     range=[0, -39.68, -1.78, 69.12, 39.68, -1.78],
                     strides=[2],
                     sizes=[[3.9, 1.6, 1.56]],
                     rotations=[0, 1.57],
                     custom_values=[],
                     reshape_out=False),
                 assigner_per_size: bool = False,
                 assign_per_class: bool = False,
                 diff_rad_by_sin: bool = True,
                 dir_offset: float = -np.pi / 2,
                 dir_limit_offset: int = 0,
                 bbox_coder: ConfigType = dict(type='DeltaXYZWLHRBBoxCoder'),
                 loss_cls: ConfigType = dict(
                     type='mmdet.CrossEntropyLoss',
                     use_sigmoid=True,
                     loss_weight=1.0),
                 loss_bbox: ConfigType = dict(
                     type='mmdet.SmoothL1Loss',
                     beta=1.0 / 9.0,
                     loss_weight=2.0),
                 loss_dir: ConfigType = dict(
                     type='mmdet.CrossEntropyLoss', loss_weight=0.2),
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 init_cfg: OptConfigType = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.feat_channels = feat_channels
        self.diff_rad_by_sin = diff_rad_by_sin
        self.use_direction_classifier = use_direction_classifier
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.assigner_per_size = assigner_per_size
        self.assign_per_class = assign_per_class
        self.dir_offset = dir_offset
        self.dir_limit_offset = dir_limit_offset
        warnings.warn(
            'dir_offset and dir_limit_offset will be deprecated and '
            'incorporated into the box coder in the future')
        self.fp16_enabled = False

        # build anchor generator
        self.prior_generator = TASK_UTILS.build(anchor_generator)
        # In 3D detection, the anchor stride is connected with anchor size
        self.num_anchors = self.prior_generator.num_base_anchors
        # build box coder
        self.bbox_coder = TASK_UTILS.build(bbox_coder)
        self.box_code_size = self.bbox_coder.code_size

        # build loss function
        self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
        self.sampling = loss_cls['type'] not in [
            'mmdet.FocalLoss', 'mmdet.GHMC'
        ]
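        # softmax classification reserves an extra channel for the
        # background class; sigmoid-based classification does not need it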
        if not self.use_sigmoid_cls:
            self.num_classes += 1
        self.loss_cls = MODELS.build(loss_cls)
        self.loss_bbox = MODELS.build(loss_bbox)
        self.loss_dir = MODELS.build(loss_dir)

        self._init_layers()
        self._init_assigner_sampler()

        if init_cfg is None:
            self.init_cfg = dict(
                type='Normal',
                layer='Conv2d',
                std=0.01,
                override=dict(
                    type='Normal', name='conv_cls', std=0.01, bias_prob=0.01))

    def _init_assigner_sampler(self):
        """Initialize the target assigner and sampler of the head."""
        if self.train_cfg is None:
            return

        if self.sampling:
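            # vanilla cross-entropy-style losses rely on explicit
            # positive/negative sampling; focal-style losses handle the
            # imbalance internally and use the PseudoSampler below, which
            # simply keeps all assigned samples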
            self.bbox_sampler = TASK_UTILS.build(self.train_cfg.sampler)
        else:
            self.bbox_sampler = PseudoSampler()
        if isinstance(self.train_cfg.assigner, dict):
            self.bbox_assigner = TASK_UTILS.build(self.train_cfg.assigner)
        elif isinstance(self.train_cfg.assigner, list):
            self.bbox_assigner = [
                TASK_UTILS.build(res) for res in self.train_cfg.assigner
            ]

    def _init_layers(self):
        """Initialize neural network layers of the head."""
        self.cls_out_channels = self.num_anchors * self.num_classes
        self.conv_cls = nn.Conv2d(self.feat_channels, self.cls_out_channels, 1)
        self.conv_reg = nn.Conv2d(self.feat_channels,
                                  self.num_anchors * self.box_code_size, 1)
        if self.use_direction_classifier:
            self.conv_dir_cls = nn.Conv2d(self.feat_channels,
                                          self.num_anchors * 2, 1)

    def forward_single(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """Forward function on a single-scale feature map.

        Args:
            x (Tensor): Features of a single scale level.

        Returns:
            tuple:
                cls_score (Tensor): Cls scores for a single scale level,
                    with channel number num_base_priors * num_classes.
                bbox_pred (Tensor): Box energies / deltas for a single scale
                    level, with channel number num_base_priors * C.
                dir_cls_pred (Tensor | None): Direction classification
                    prediction for a single scale level, with channel
                    number num_base_priors * 2.
        """
        cls_score = self.conv_cls(x)
        bbox_pred = self.conv_reg(x)
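        # cls_score: (N, num_anchors * num_classes, H, W)
        # bbox_pred: (N, num_anchors * box_code_size, H, W)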
        dir_cls_pred = None
        if self.use_direction_classifier:
            dir_cls_pred = self.conv_dir_cls(x)
        return cls_score, bbox_pred, dir_cls_pred

    def forward(self, x: Tuple[Tensor]) -> Tuple[List[Tensor]]:
        """Forward pass.

        Args:
            x (tuple[Tensor]): Features from the upstream network,
                each is a 4D-tensor.

        Returns:
            tuple: A tuple of classification scores, bbox predictions and
                direction classification predictions.

                - cls_scores (list[Tensor]): Classification scores for all
                    scale levels, each a 4D-tensor with channel number
                    num_base_priors * num_classes.
                - bbox_preds (list[Tensor]): Box energies / deltas for all
                    scale levels, each a 4D-tensor with channel number
                    num_base_priors * C.
                - dir_cls_preds (list[Tensor | None]): Direction
                    classification predictions for all scale levels, each a
                    4D-tensor with channel number num_base_priors * 2.
        """
        return multi_apply(self.forward_single, x)

    # TODO: Support augmentation test
    def aug_test(self,
                 aug_batch_feats,
                 aug_batch_input_metas,
                 rescale=False,
                 **kwargs):
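        """Test with augmentations.

        Predictions from each augmented sample are merged with
        ``merge_aug_bboxes_3d`` according to ``test_cfg``.
        """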
        aug_bboxes = []
        # only support aug_test for one sample
        for x, input_meta in zip(aug_batch_feats, aug_batch_input_metas):
            outs = self.forward(x)
            bbox_list = self.get_results(*outs, [input_meta], rescale=rescale)
            bbox_dict = dict(
                bboxes_3d=bbox_list[0].bboxes_3d,
                scores_3d=bbox_list[0].scores_3d,
                labels_3d=bbox_list[0].labels_3d)
            aug_bboxes.append(bbox_dict)
        # after merging, bboxes will be rescaled to the original image size
        merged_bboxes = merge_aug_bboxes_3d(aug_bboxes, aug_batch_input_metas,
                                            self.test_cfg)
        return [merged_bboxes]

    def get_anchors(self,
                    featmap_sizes: List[tuple],
                    input_metas: List[dict],
                    device: str = 'cuda') -> list:
        """Get anchors according to feature map sizes.

        Args:
            featmap_sizes (list[tuple]): Multi-level feature map sizes.
            input_metas (list[dict]): Contains pcd and img's meta info.
            device (str): Device of current module.

        Returns:
            list[list[torch.Tensor]]: Anchors of each image, valid flags
                of each image.
        """
        num_imgs = len(input_metas)
        # since feature map sizes of all images are the same, we only compute
        # anchors once
        multi_level_anchors = self.prior_generator.grid_anchors(
            featmap_sizes, device=device)
        anchor_list = [multi_level_anchors for _ in range(num_imgs)]
        return anchor_list

    def _loss_by_feat_single(self, cls_score: Tensor, bbox_pred: Tensor,
                             dir_cls_pred: Tensor, labels: Tensor,
                             label_weights: Tensor, bbox_targets: Tensor,
                             bbox_weights: Tensor, dir_targets: Tensor,
                             dir_weights: Tensor,
                             num_total_samples: Optional[int]):
        """Calculate loss of Single-level results.

        Args:
            cls_score (Tensor): Class score at a single scale level.
            bbox_pred (Tensor): Bbox prediction at a single scale level.
            dir_cls_pred (Tensor): Predictions of direction class
                at a single scale level.
            labels (Tensor): Labels of class.
            label_weights (Tensor): Weights of class loss.
            bbox_targets (Tensor): Targets of bbox predictions.
            bbox_weights (Tensor): Weights of bbox loss.
            dir_targets (Tensor): Targets of direction predictions.
            dir_weights (Tensor): Weights of direction loss.
            num_total_samples (int): The number of valid samples.

        Returns:
            tuple[torch.Tensor]: Losses of class, bbox
                and direction, respectively.
        """
        # classification loss
        if num_total_samples is None:
            num_total_samples = int(cls_score.shape[0])
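        # flatten targets and predictions so that every row corresponds to
        # a single anchor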
        labels = labels.reshape(-1)
        label_weights = label_weights.reshape(-1)
        cls_score = cls_score.permute(0, 2, 3, 1).reshape(-1, self.num_classes)
        assert labels.max().item() <= self.num_classes
        loss_cls = self.loss_cls(
            cls_score, labels, label_weights, avg_factor=num_total_samples)

        # regression loss
        bbox_pred = bbox_pred.permute(0, 2, 3,
                                      1).reshape(-1, self.box_code_size)
        bbox_targets = bbox_targets.reshape(-1, self.box_code_size)
        bbox_weights = bbox_weights.reshape(-1, self.box_code_size)

        bg_class_ind = self.num_classes
        pos_inds = ((labels >= 0)
                    & (labels < bg_class_ind)).nonzero(
                        as_tuple=False).reshape(-1)
        num_pos = len(pos_inds)

        pos_bbox_pred = bbox_pred[pos_inds]
        pos_bbox_targets = bbox_targets[pos_inds]
        pos_bbox_weights = bbox_weights[pos_inds]

        # dir loss
        if self.use_direction_classifier:
            dir_cls_pred = dir_cls_pred.permute(0, 2, 3, 1).reshape(-1, 2)
            dir_targets = dir_targets.reshape(-1)
            dir_weights = dir_weights.reshape(-1)
            pos_dir_cls_pred = dir_cls_pred[pos_inds]
            pos_dir_targets = dir_targets[pos_inds]
            pos_dir_weights = dir_weights[pos_inds]

        if num_pos > 0:
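            # code_weight optionally re-weights each dimension of the box
            # code (x, y, z, w, l, h, yaw, ...) in the regression loss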
            code_weight = self.train_cfg.get('code_weight', None)
            if code_weight:
                pos_bbox_weights = pos_bbox_weights * bbox_weights.new_tensor(
                    code_weight)
            if self.diff_rad_by_sin:
                pos_bbox_pred, pos_bbox_targets = self.add_sin_difference(
                    pos_bbox_pred, pos_bbox_targets)
            loss_bbox = self.loss_bbox(
                pos_bbox_pred,
                pos_bbox_targets,
                pos_bbox_weights,
                avg_factor=num_total_samples)

            # direction classification loss
            loss_dir = None
            if self.use_direction_classifier:
                loss_dir = self.loss_dir(
                    pos_dir_cls_pred,
                    pos_dir_targets,
                    pos_dir_weights,
                    avg_factor=num_total_samples)
        else:
            loss_bbox = pos_bbox_pred.sum()
            if self.use_direction_classifier:
                loss_dir = pos_dir_cls_pred.sum()

        return loss_cls, loss_bbox, loss_dir

    @staticmethod
    def add_sin_difference(boxes1: Tensor, boxes2: Tensor) -> tuple:
        """Convert the rotation difference to difference in sine function.

        Args:
            boxes1 (torch.Tensor): Original Boxes in shape (NxC), where C>=7
                and the 7th dimension is rotation dimension.
            boxes2 (torch.Tensor): Target boxes in shape (NxC), where C>=7 and
                the 7th dimension is rotation dimension.

        Returns:
            tuple[torch.Tensor]: ``boxes1`` and ``boxes2`` whose 7th
                dimensions are changed.
        """
        rad_pred_encoding = torch.sin(boxes1[..., 6:7]) * torch.cos(
            boxes2[..., 6:7])
        rad_tg_encoding = torch.cos(boxes1[..., 6:7]) * torch.sin(boxes2[...,
                                                                         6:7])
        boxes1 = torch.cat(
            [boxes1[..., :6], rad_pred_encoding, boxes1[..., 7:]], dim=-1)
        boxes2 = torch.cat([boxes2[..., :6], rad_tg_encoding, boxes2[..., 7:]],
                           dim=-1)
        return boxes1, boxes2

    def loss_by_feat(
            self,
            cls_scores: List[Tensor],
            bbox_preds: List[Tensor],
            dir_cls_preds: List[Tensor],
            batch_gt_instances_3d: InstanceList,
            batch_input_metas: List[dict],
            batch_gt_instances_ignore: OptInstanceList = None) -> dict:
        """Calculate the loss based on the features extracted by the detection
        head.

        Args:
            cls_scores (list[torch.Tensor]): Multi-level class scores.
            bbox_preds (list[torch.Tensor]): Multi-level bbox predictions.
            dir_cls_preds (list[torch.Tensor]): Multi-level direction
                class predictions.
            batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
                gt_instances. It usually includes ``bboxes_3d``
                and ``labels_3d`` attributes.
            batch_input_metas (list[dict]): Contain pcd and img's meta info.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.

        Returns:
            dict[str, list[torch.Tensor]]: Classification, bbox, and
                direction losses of each level.

                - loss_cls (list[torch.Tensor]): Classification losses.
                - loss_bbox (list[torch.Tensor]): Box regression losses.
                - loss_dir (list[torch.Tensor]): Direction classification
                    losses.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == self.prior_generator.num_levels
        device = cls_scores[0].device
        anchor_list = self.get_anchors(
            featmap_sizes, batch_input_metas, device=device)
        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
        cls_reg_targets = self.anchor_target_3d(
            anchor_list,
            batch_gt_instances_3d,
            batch_input_metas,
            batch_gt_instances_ignore=batch_gt_instances_ignore,
            num_classes=self.num_classes,
            label_channels=label_channels,
            sampling=self.sampling)

        if cls_reg_targets is None:
            return None
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         dir_targets_list, dir_weights_list, num_total_pos,
         num_total_neg) = cls_reg_targets
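        # with explicit sampling the loss is averaged over all sampled
        # anchors; otherwise (e.g. for focal loss) only over the positives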
        num_total_samples = (
            num_total_pos + num_total_neg if self.sampling else num_total_pos)

        losses_cls, losses_bbox, losses_dir = multi_apply(
            self._loss_by_feat_single,
            cls_scores,
            bbox_preds,
            dir_cls_preds,
            labels_list,
            label_weights_list,
            bbox_targets_list,
            bbox_weights_list,
            dir_targets_list,
            dir_weights_list,
            num_total_samples=num_total_samples)
        return dict(
            loss_cls=losses_cls, loss_bbox=losses_bbox, loss_dir=losses_dir)