# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import List, Optional, Tuple

import numpy as np
import torch
from mmcv import ConfigDict
from mmcv.runner import BaseModule, force_fp32
from mmengine.data import InstanceData
from torch import Tensor
from torch import nn as nn

from mmdet3d.core import (Det3DDataSample, PseudoSampler, box3d_multiclass_nms,
                          limit_period, merge_aug_bboxes_3d, xywhr2xyxyr)
from mmdet3d.registry import MODELS, TASK_UTILS
from mmdet.core import multi_apply
from .train_mixins import AnchorTrainMixin


@MODELS.register_module()
class Anchor3DHead(BaseModule, AnchorTrainMixin):
    """Anchor head for SECOND/PointPillars/MVXNet/PartA2.

    Args:
        num_classes (int): Number of classes.
        in_channels (int): Number of channels in the input feature map.
        train_cfg (dict): Train configs.
        test_cfg (dict): Test configs.
        feat_channels (int): Number of channels of the feature map.
        use_direction_classifier (bool): Whether to add a direction classifier.
        anchor_generator (dict): Config dict of anchor generator.
        assigner_per_size (bool): Whether to do assignment for each separate
            anchor size.
        assign_per_class (bool): Whether to do assignment for each class.
        diff_rad_by_sin (bool): Whether to change the difference into sin
            difference for box regression loss.
        dir_offset (float | int): The offset of BEV rotation angles.
            (TODO: may be moved into box coder)
        dir_limit_offset (float | int): The limited range of BEV
            rotation angles. (TODO: may be moved into box coder)
        bbox_coder (dict): Config dict of box coders.
        loss_cls (dict): Config of classification loss.
        loss_bbox (dict): Config of localization loss.
        loss_dir (dict): Config of direction classifier loss.
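
    Example:
        >>> # An illustrative sketch, not a canonical config: build a
        >>> # single-class head and run it on one BEV feature map. This
        >>> # assumes the default anchor generator, box coder and losses
        >>> # are registered; the channel and spatial sizes are arbitrary.
        >>> head = Anchor3DHead(
        ...     num_classes=1,
        ...     in_channels=384,
        ...     feat_channels=384,
        ...     train_cfg=None,
        ...     test_cfg=None)
        >>> feats = [torch.rand(2, 384, 248, 216)]
        >>> cls_scores, bbox_preds, dir_preds = head(feats)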
    """

    def __init__(self,
                 num_classes: int,
                 in_channels: int,
                 train_cfg: dict,
                 test_cfg: dict,
                 feat_channels: int = 256,
                 use_direction_classifier: bool = True,
                 anchor_generator: dict = dict(
                     type='Anchor3DRangeGenerator',
                     range=[0, -39.68, -1.78, 69.12, 39.68, -1.78],
                     strides=[2],
                     sizes=[[3.9, 1.6, 1.56]],
                     rotations=[0, 1.57],
                     custom_values=[],
                     reshape_out=False),
                 assigner_per_size: bool = False,
                 assign_per_class: bool = False,
                 diff_rad_by_sin: bool = True,
                 dir_offset: float = -np.pi / 2,
                 dir_limit_offset: int = 0,
                 bbox_coder: dict = dict(type='DeltaXYZWLHRBBoxCoder'),
                 loss_cls: dict = dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     loss_weight=1.0),
                 loss_bbox: dict = dict(
                     type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=2.0),
                 loss_dir: dict = dict(
                     type='CrossEntropyLoss', loss_weight=0.2),
                 init_cfg: Optional[dict] = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.feat_channels = feat_channels
        self.diff_rad_by_sin = diff_rad_by_sin
        self.use_direction_classifier = use_direction_classifier
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.assigner_per_size = assigner_per_size
        self.assign_per_class = assign_per_class
        self.dir_offset = dir_offset
        self.dir_limit_offset = dir_limit_offset
        warnings.warn(
            'dir_offset and dir_limit_offset will be deprecated and '
            'incorporated into the box coder in the future')
        self.fp16_enabled = False

        # build anchor generator
        self.prior_generator = TASK_UTILS.build(anchor_generator)
        # In 3D detection, the anchor stride is connected with anchor size
        self.num_anchors = self.prior_generator.num_base_anchors
        # build box coder
        self.bbox_coder = TASK_UTILS.build(bbox_coder)
        self.box_code_size = self.bbox_coder.code_size

        # build loss function
        self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
        self.sampling = loss_cls['type'] not in [
            'mmdet.FocalLoss', 'mmdet.GHMC'
        ]
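        # softmax-based classification reserves an extra channel for the
        # background class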
        if not self.use_sigmoid_cls:
            self.num_classes += 1
        self.loss_cls = MODELS.build(loss_cls)
        self.loss_bbox = MODELS.build(loss_bbox)
        self.loss_dir = MODELS.build(loss_dir)

        self._init_layers()
        self._init_assigner_sampler()

        if init_cfg is None:
            self.init_cfg = dict(
                type='Normal',
                layer='Conv2d',
                std=0.01,
                override=dict(
                    type='Normal', name='conv_cls', std=0.01, bias_prob=0.01))

    def _init_assigner_sampler(self):
        """Initialize the target assigner and sampler of the head."""
        if self.train_cfg is None:
            return

        if self.sampling:
            self.bbox_sampler = TASK_UTILS.build(self.train_cfg.sampler)
        else:
            self.bbox_sampler = PseudoSampler()
        if isinstance(self.train_cfg.assigner, dict):
            self.bbox_assigner = TASK_UTILS.build(self.train_cfg.assigner)
        elif isinstance(self.train_cfg.assigner, list):
            self.bbox_assigner = [
                TASK_UTILS.build(res) for res in self.train_cfg.assigner
            ]

    def _init_layers(self):
        """Initialize neural network layers of the head."""
        self.cls_out_channels = self.num_anchors * self.num_classes
        self.conv_cls = nn.Conv2d(self.feat_channels, self.cls_out_channels, 1)
        self.conv_reg = nn.Conv2d(self.feat_channels,
                                  self.num_anchors * self.box_code_size, 1)
        if self.use_direction_classifier:
            self.conv_dir_cls = nn.Conv2d(self.feat_channels,
                                          self.num_anchors * 2, 1)

    def forward_single(
            self, x: Tensor) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
        """Forward function on a single-scale feature map.

        Args:
            x (torch.Tensor): Input features.

        Returns:
            tuple[torch.Tensor]: Contains the score of each class, bbox
                regression and direction classification predictions.
        """
        cls_score = self.conv_cls(x)
        bbox_pred = self.conv_reg(x)
        dir_cls_preds = None
        if self.use_direction_classifier:
            dir_cls_preds = self.conv_dir_cls(x)
        return cls_score, bbox_pred, dir_cls_preds

    def forward(self, feats: List[Tensor]) -> Tuple[list]:
        """Forward pass.

        Args:
            feats (list[torch.Tensor]): Multi-level features, e.g.,
                features produced by FPN.

        Returns:
            tuple[list[torch.Tensor]]: Multi-level class score, bbox
                and direction predictions.
        """
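        # multi_apply runs `forward_single` on each feature level and
        # regroups the per-level outputs into separate lists of class
        # scores, bbox predictions and direction predictions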
        return multi_apply(self.forward_single, feats)

    def forward_train(self,
                      feats: List[Tensor],
                      batch_data_samples: List[Det3DDataSample],
                      proposal_cfg: Optional[ConfigDict] = None,
                      **kwargs):
        """
        Args:
            feats (list[Tensor]): Features from FPN.
            batch_data_samples (list[:obj:`Det3DDataSample`]): Each item
                contains the meta information of each sample and
                corresponding annotations.
            proposal_cfg (ConfigDict, optional): Test / postprocessing
                configuration. If None, test_cfg will be used.
                Defaults to None.

        Returns:
            tuple or Tensor: When `proposal_cfg` is None, the detector is a \
            normal one-stage detector, and the return value is the losses.

            - losses (dict[str, Tensor]): A dictionary of loss components.

            When `proposal_cfg` is not None, the head is used as an
            `rpn_head`, and the return value is a tuple that contains:

            - losses (dict[str, Tensor]): A dictionary of loss components.
            - results_list (list[:obj:`InstanceData`]): Detection
              results of each input after the post process.
              Each item usually contains the following keys.

                - scores_3d (Tensor): Classification scores, has a shape
                  (num_instances, )
                - labels_3d (Tensor): Labels of bboxes, has a shape
                  (num_instances, ).
                - bboxes_3d (:obj:`BaseInstance3DBoxes`): Prediction of bboxes,
                  contains a tensor with shape (num_instances, 7).
        """
        outs = self.forward(feats)

        batch_gt_instance_3d = []
        batch_gt_instances_ignore = []
        batch_input_metas = []
        for data_sample in batch_data_samples:
            batch_input_metas.append(data_sample.metainfo)
            batch_gt_instance_3d.append(data_sample.gt_instances_3d)
            if 'ignored_instances' in data_sample:
                batch_gt_instances_ignore.append(data_sample.ignored_instances)
            else:
                batch_gt_instances_ignore.append(None)

        loss_inputs = outs + (batch_gt_instance_3d, batch_input_metas)
        losses = self.loss(
            *loss_inputs, batch_gt_instances_ignore=batch_gt_instances_ignore)
        if proposal_cfg is None:
            return losses
        else:
            results_list = self.get_results(
                *outs, input_metas=batch_input_metas, cfg=proposal_cfg)
            return losses, results_list

    def simple_test(self,
                    feats: Tuple[Tensor],
                    batch_input_metas: List[dict],
                    rescale: bool = False) -> List[InstanceData]:
        """Test function without test-time augmentation.

        Args:
            feats (tuple[torch.Tensor]): Multi-level features from the
                upstream network, each is a 4D-tensor.
            batch_input_metas (list[dict]): List of image information.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.

        Returns:
            list[:obj:`InstanceData`]: Detection results of each input
            after the post process.
            Each item usually contains the following keys.

            - scores_3d (Tensor): Classification scores, has a shape
              (num_instances, )
            - labels_3d (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes_3d (BaseInstance3DBoxes): Prediction of bboxes,
              contains a tensor with shape (num_instances, 7).
        """
        outs = self.forward(feats)
        results_list = self.get_results(
            *outs, input_metas=batch_input_metas, rescale=rescale)
        return results_list

    def aug_test(self,
                 aug_batch_feats: List[List[Tensor]],
                 aug_batch_input_metas: List[dict],
                 rescale: bool = False,
                 **kwargs) -> List[dict]:
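        """Test function with test-time augmentation.

        Args:
            aug_batch_feats (list[list[torch.Tensor]]): Multi-level features
                of each augmented view.
            aug_batch_input_metas (list[dict]): Meta info of each augmented
                view.
            rescale (bool): Whether to rescale bboxes. Defaults to False.

        Returns:
            list[dict]: Merged detection results; only a single sample is
                supported.
        """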
        aug_bboxes = []
        # only support aug_test for one sample
        for x, input_meta in zip(aug_batch_feats, aug_batch_input_metas):
            outs = self.forward(x)
            bbox_list = self.get_results(*outs, [input_meta], rescale=rescale)
            bbox_dict = dict(
                bboxes_3d=bbox_list[0].bboxes_3d,
                scores_3d=bbox_list[0].scores_3d,
                labels_3d=bbox_list[0].labels_3d)
            aug_bboxes.append(bbox_dict)
        # after merging, bboxes will be rescaled to the original image size
        merged_bboxes = merge_aug_bboxes_3d(aug_bboxes, aug_batch_input_metas,
                                            self.test_cfg)
        return [merged_bboxes]

    def get_anchors(self,
                    featmap_sizes: List[tuple],
                    input_metas: List[dict],
                    device: str = 'cuda') -> list:
        """Get anchors according to feature map sizes.

        Args:
            featmap_sizes (list[tuple]): Multi-level feature map sizes.
            input_metas (list[dict]): Contain pcd and img's meta info.
            device (str): Device of the current module.

        Returns:
            list[list[torch.Tensor]]: Anchors of each image.
        """
        num_imgs = len(input_metas)
        # since feature map sizes of all images are the same, we only compute
        # anchors once
        multi_level_anchors = self.prior_generator.grid_anchors(
            featmap_sizes, device=device)
        anchor_list = [multi_level_anchors for _ in range(num_imgs)]
        return anchor_list

    def loss_single(self, cls_score, bbox_pred, dir_cls_preds, labels,
                    label_weights, bbox_targets, bbox_weights, dir_targets,
                    dir_weights, num_total_samples):
        """Calculate loss of single-level results.

        Args:
            cls_score (torch.Tensor): Class score in single-level.
            bbox_pred (torch.Tensor): Bbox prediction in single-level.
            dir_cls_preds (torch.Tensor): Predictions of direction class
                in single-level.
            labels (torch.Tensor): Labels of class.
            label_weights (torch.Tensor): Weights of class loss.
            bbox_targets (torch.Tensor): Targets of bbox predictions.
            bbox_weights (torch.Tensor): Weights of bbox loss.
            dir_targets (torch.Tensor): Targets of direction predictions.
            dir_weights (torch.Tensor): Weights of direction loss.
            num_total_samples (int): The number of valid samples.

        Returns:
            tuple[torch.Tensor]: Losses of class, bbox
                and direction, respectively.
        """
        # classification loss
        if num_total_samples is None:
            num_total_samples = int(cls_score.shape[0])
        labels = labels.reshape(-1)
        label_weights = label_weights.reshape(-1)
        cls_score = cls_score.permute(0, 2, 3, 1).reshape(-1, self.num_classes)
        assert labels.max().item() <= self.num_classes
        loss_cls = self.loss_cls(
            cls_score, labels, label_weights, avg_factor=num_total_samples)

        # regression loss
        bbox_pred = bbox_pred.permute(0, 2, 3,
                                      1).reshape(-1, self.box_code_size)
        bbox_targets = bbox_targets.reshape(-1, self.box_code_size)
        bbox_weights = bbox_weights.reshape(-1, self.box_code_size)

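        # anchors labeled in [0, num_classes) are positive samples;
        # background anchors carry the label num_classes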
        bg_class_ind = self.num_classes
        pos_inds = ((labels >= 0)
                    & (labels < bg_class_ind)).nonzero(
                        as_tuple=False).reshape(-1)
        num_pos = len(pos_inds)

        pos_bbox_pred = bbox_pred[pos_inds]
        pos_bbox_targets = bbox_targets[pos_inds]
        pos_bbox_weights = bbox_weights[pos_inds]

        # dir loss
        if self.use_direction_classifier:
            dir_cls_preds = dir_cls_preds.permute(0, 2, 3, 1).reshape(-1, 2)
            dir_targets = dir_targets.reshape(-1)
            dir_weights = dir_weights.reshape(-1)
            pos_dir_cls_preds = dir_cls_preds[pos_inds]
            pos_dir_targets = dir_targets[pos_inds]
            pos_dir_weights = dir_weights[pos_inds]

        if num_pos > 0:
            code_weight = self.train_cfg.get('code_weight', None)
            if code_weight:
                pos_bbox_weights = pos_bbox_weights * bbox_weights.new_tensor(
                    code_weight)
            if self.diff_rad_by_sin:
                pos_bbox_pred, pos_bbox_targets = self.add_sin_difference(
                    pos_bbox_pred, pos_bbox_targets)
            loss_bbox = self.loss_bbox(
                pos_bbox_pred,
                pos_bbox_targets,
                pos_bbox_weights,
                avg_factor=num_total_samples)

            # direction classification loss
            loss_dir = None
            if self.use_direction_classifier:
                loss_dir = self.loss_dir(
                    pos_dir_cls_preds,
                    pos_dir_targets,
                    pos_dir_weights,
                    avg_factor=num_total_samples)
        else:
            loss_bbox = pos_bbox_pred.sum()
            if self.use_direction_classifier:
                loss_dir = pos_dir_cls_preds.sum()

        return loss_cls, loss_bbox, loss_dir

    @staticmethod
    def add_sin_difference(boxes1: Tensor,
                           boxes2: Tensor) -> Tuple[Tensor, Tensor]:
        """Convert the rotation difference to difference in sine function.

        Args:
            boxes1 (torch.Tensor): Original Boxes in shape (NxC), where C>=7
                and the 7th dimension is rotation dimension.
            boxes2 (torch.Tensor): Target boxes in shape (NxC), where C>=7 and
                the 7th dimension is rotation dimension.

        Returns:
            tuple[torch.Tensor]: ``boxes1`` and ``boxes2`` whose 7th
                dimensions are changed.
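
        Example:
            >>> # A minimal sketch: after encoding, the difference of the
            >>> # 7th (yaw) dims equals sin(r1 - r2).
            >>> boxes1 = torch.zeros(1, 7)
            >>> boxes2 = torch.zeros(1, 7)
            >>> boxes1[0, 6], boxes2[0, 6] = 0.5, 0.2
            >>> b1, b2 = Anchor3DHead.add_sin_difference(boxes1, boxes2)
            >>> torch.allclose(b1[..., 6] - b2[..., 6],
            ...                torch.sin(torch.tensor(0.3)))
            True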
        """
        rad_pred_encoding = torch.sin(boxes1[..., 6:7]) * torch.cos(
            boxes2[..., 6:7])
        rad_tg_encoding = torch.cos(boxes1[..., 6:7]) * torch.sin(boxes2[...,
                                                                         6:7])
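        # by sin(a - b) = sin(a)cos(b) - cos(a)sin(b), the difference of the
        # two encodings equals the sine of the rotation difference, avoiding
        # the 2*pi discontinuity in the regression loss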
        boxes1 = torch.cat(
            [boxes1[..., :6], rad_pred_encoding, boxes1[..., 7:]], dim=-1)
        boxes2 = torch.cat([boxes2[..., :6], rad_tg_encoding, boxes2[..., 7:]],
                           dim=-1)
        return boxes1, boxes2

    @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'dir_cls_preds'))
    def loss(self,
             cls_scores: List[Tensor],
             bbox_preds: List[Tensor],
             dir_cls_preds: List[Tensor],
             batch_gt_instances_3d: List[InstanceData],
             batch_input_metas: List[dict],
             batch_gt_instances_ignore: List[InstanceData] = None) -> dict:
        """Calculate losses.

        Args:
            cls_scores (list[torch.Tensor]): Multi-level class scores.
            bbox_preds (list[torch.Tensor]): Multi-level bbox predictions.
            dir_cls_preds (list[torch.Tensor]): Multi-level direction
                class predictions.
            batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
                gt_instances_3d. It usually includes ``bboxes_3d`` and
                ``labels_3d`` attributes.
            batch_input_metas (list[dict]): Contain pcd and img's meta info.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.

        Returns:
            dict[str, list[torch.Tensor]]: Classification, bbox, and
                direction losses of each level.

                - loss_cls (list[torch.Tensor]): Classification losses.
                - loss_bbox (list[torch.Tensor]): Box regression losses.
                - loss_dir (list[torch.Tensor]): Direction classification
                    losses.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == self.prior_generator.num_levels
        device = cls_scores[0].device
        anchor_list = self.get_anchors(
            featmap_sizes, batch_input_metas, device=device)
        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
        cls_reg_targets = self.anchor_target_3d(
            anchor_list,
            batch_gt_instances_3d,
            batch_input_metas,
            batch_gt_instances_ignore=batch_gt_instances_ignore,
            num_classes=self.num_classes,
            label_channels=label_channels,
            sampling=self.sampling)

        if cls_reg_targets is None:
            return None
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         dir_targets_list, dir_weights_list, num_total_pos,
         num_total_neg) = cls_reg_targets
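        # when a real sampler is used, losses are normalized by the number of
        # sampled positives and negatives; focal-style losses average over
        # positives only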
        num_total_samples = (
            num_total_pos + num_total_neg if self.sampling else num_total_pos)

        losses_cls, losses_bbox, losses_dir = multi_apply(
            self.loss_single,
            cls_scores,
            bbox_preds,
            dir_cls_preds,
            labels_list,
            label_weights_list,
            bbox_targets_list,
            bbox_weights_list,
            dir_targets_list,
            dir_weights_list,
            num_total_samples=num_total_samples)
        return dict(
            loss_cls=losses_cls, loss_bbox=losses_bbox, loss_dir=losses_dir)

    def get_results(self,
                    cls_scores: List[Tensor],
                    bbox_preds: List[Tensor],
                    dir_cls_preds: List[Tensor],
                    input_metas: List[dict],
                    cfg: Optional[ConfigDict] = None,
                    rescale: bool = False) -> List[InstanceData]:
        """Get results of anchor head.

        Args:
            cls_scores (list[torch.Tensor]): Multi-level class scores.
            bbox_preds (list[torch.Tensor]): Multi-level bbox predictions.
            dir_cls_preds (list[torch.Tensor]): Multi-level direction
                class predictions.
            input_metas (list[dict]): Contain pcd and img's meta info.
            cfg (:obj:`ConfigDict`): Training or testing config.
            rescale (bool): Whether to rescale bboxes.

        Returns:
            list[:obj:`InstanceData`]: Instance prediction
            results of each sample after the post process.
            Each item usually contains the following keys.

                - scores_3d (Tensor): Classification scores, has a shape
                  (num_instances, ).
                - labels_3d (Tensor): Labels of bboxes, has a shape
                  (num_instances, ).
                - bboxes_3d (:obj:`BaseInstance3DBoxes`): Prediction of bboxes,
                  contains a tensor with shape (num_instances, 7).
        """
        assert len(cls_scores) == len(bbox_preds)
        assert len(cls_scores) == len(dir_cls_preds)
        num_levels = len(cls_scores)
        featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
        device = cls_scores[0].device
        mlvl_anchors = self.prior_generator.grid_anchors(
            featmap_sizes, device=device)
        mlvl_anchors = [
            anchor.reshape(-1, self.box_code_size) for anchor in mlvl_anchors
        ]

        result_list = []
        for img_id in range(len(input_metas)):
            cls_score_list = [
                cls_scores[i][img_id].detach() for i in range(num_levels)
            ]
            bbox_pred_list = [
                bbox_preds[i][img_id].detach() for i in range(num_levels)
            ]
            dir_cls_pred_list = [
                dir_cls_preds[i][img_id].detach() for i in range(num_levels)
            ]

            input_meta = input_metas[img_id]
            proposals = self._get_results_single(cls_score_list,
                                                 bbox_pred_list,
                                                 dir_cls_pred_list,
                                                 mlvl_anchors, input_meta, cfg,
                                                 rescale)
            result_list.append(proposals)
        return result_list

    def _get_results_single(self,
                            cls_scores: Tensor,
                            bbox_preds: Tensor,
                            dir_cls_preds: Tensor,
                            mlvl_anchors: List[Tensor],
                            input_meta: dict,
                            cfg: Optional[ConfigDict] = None,
                            rescale: bool = False) -> InstanceData:
        """Get results of single branch.

        Args:
            cls_scores (torch.Tensor): Class score in single batch.
            bbox_preds (torch.Tensor): Bbox prediction in single batch.
            dir_cls_preds (torch.Tensor): Predictions of direction class
                in single batch.
            mlvl_anchors (List[torch.Tensor]): Multi-level anchors
                in single batch.
            input_meta (dict): Contain pcd and img's meta info.
            cfg (:obj:`ConfigDict`): Training or testing config.
            rescale (bool): Whether to rescale bboxes.

        Returns:
            :obj:`InstanceData`: Detection results of each sample
            after the post process.
            Each item usually contains the following keys.

                - scores_3d (Tensor): Classification scores, has a shape
                  (num_instances, ).
                - labels_3d (Tensor): Labels of bboxes, has a shape
                  (num_instances, ).
                - bboxes_3d (:obj:`BaseInstance3DBoxes`): Prediction of bboxes,
                  contains a tensor with shape (num_instances, 7).
        """
        cfg = self.test_cfg if cfg is None else cfg
        assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors)
        mlvl_bboxes = []
        mlvl_scores = []
        mlvl_dir_scores = []
        for cls_score, bbox_pred, dir_cls_pred, anchors in zip(
                cls_scores, bbox_preds, dir_cls_preds, mlvl_anchors):
            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
            assert cls_score.size()[-2:] == dir_cls_pred.size()[-2:]
            dir_cls_pred = dir_cls_pred.permute(1, 2, 0).reshape(-1, 2)
            dir_cls_score = torch.max(dir_cls_pred, dim=-1)[1]

            cls_score = cls_score.permute(1, 2,
                                          0).reshape(-1, self.num_classes)
            if self.use_sigmoid_cls:
                scores = cls_score.sigmoid()
            else:
                scores = cls_score.softmax(-1)
            bbox_pred = bbox_pred.permute(1, 2,
                                          0).reshape(-1, self.box_code_size)

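            # keep only the top-scoring nms_pre candidates on each level to
            # bound the cost of the multi-class NMS below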
            nms_pre = cfg.get('nms_pre', -1)
            if nms_pre > 0 and scores.shape[0] > nms_pre:
                if self.use_sigmoid_cls:
                    max_scores, _ = scores.max(dim=1)
                else:
                    max_scores, _ = scores[:, :-1].max(dim=1)
                _, topk_inds = max_scores.topk(nms_pre)
                anchors = anchors[topk_inds, :]
                bbox_pred = bbox_pred[topk_inds, :]
                scores = scores[topk_inds, :]
                dir_cls_score = dir_cls_score[topk_inds]

            bboxes = self.bbox_coder.decode(anchors, bbox_pred)
            mlvl_bboxes.append(bboxes)
            mlvl_scores.append(scores)
            mlvl_dir_scores.append(dir_cls_score)

        mlvl_bboxes = torch.cat(mlvl_bboxes)
        mlvl_bboxes_for_nms = xywhr2xyxyr(input_meta['box_type_3d'](
            mlvl_bboxes, box_dim=self.box_code_size).bev)
        mlvl_scores = torch.cat(mlvl_scores)
        mlvl_dir_scores = torch.cat(mlvl_dir_scores)

        if self.use_sigmoid_cls:
            # Add a dummy background class to the end when using sigmoid
            padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
            mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)

        score_thr = cfg.get('score_thr', 0)
        results = box3d_multiclass_nms(mlvl_bboxes, mlvl_bboxes_for_nms,
                                       mlvl_scores, score_thr, cfg.max_num,
                                       cfg, mlvl_dir_scores)
        bboxes, scores, labels, dir_scores = results
        if bboxes.shape[0] > 0:
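            # recover yaw: fold the regressed angle into one period starting
            # at dir_offset, then add pi for boxes whose direction classifier
            # picked the flipped bin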
            dir_rot = limit_period(bboxes[..., 6] - self.dir_offset,
                                   self.dir_limit_offset, np.pi)
            bboxes[..., 6] = (
                dir_rot + self.dir_offset +
                np.pi * dir_scores.to(bboxes.dtype))
        bboxes = input_meta['box_type_3d'](bboxes, box_dim=self.box_code_size)
        results = InstanceData()
        results.bboxes_3d = bboxes
        results.scores_3d = scores
        results.labels_3d = labels
        return results