# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import List, Tuple

import numpy as np
import torch
from mmdet.models.utils import multi_apply
from torch import Tensor
from torch import nn as nn

from mmdet3d.models.task_modules import PseudoSampler
from mmdet3d.models.test_time_augs import merge_aug_bboxes_3d
from mmdet3d.registry import MODELS, TASK_UTILS
from mmdet3d.utils.typing import (ConfigType, InstanceList, OptConfigType,
                                  OptInstanceList)
from .base_3d_dense_head import Base3DDenseHead
from .train_mixins import AnchorTrainMixin
@MODELS.register_module()
class Anchor3DHead(Base3DDenseHead, AnchorTrainMixin):
    """Anchor-based head for SECOND/PointPillars/MVXNet/PartA2.

    Args:
        num_classes (int): Number of classes.
        in_channels (int): Number of channels in the input feature map.
        feat_channels (int): Number of channels of the feature map.
        use_direction_classifier (bool): Whether to add a direction classifier.
        anchor_generator(dict): Config dict of anchor generator.
        assigner_per_size (bool): Whether to do assignment for each separate
            anchor size.
        assign_per_class (bool): Whether to do assignment for each class.
        diff_rad_by_sin (bool): Whether to change the difference into sin
            difference for box regression loss.
        dir_offset (float | int): The offset of BEV rotation angles.
            (TODO: may be moved into box coder)
        dir_limit_offset (float | int): The limited range of BEV
            rotation angles. (TODO: may be moved into box coder)
        bbox_coder (dict): Config dict of box coders.
        loss_cls (dict): Config of classification loss.
        loss_bbox (dict): Config of localization loss.
        loss_dir (dict): Config of direction classifier loss.
        train_cfg (dict): Train configs.
        test_cfg (dict): Test configs.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 num_classes: int,
                 in_channels: int,
                 feat_channels: int = 256,
                 use_direction_classifier: bool = True,
                 anchor_generator: ConfigType = dict(
                     type='Anchor3DRangeGenerator',
                     range=[0, -39.68, -1.78, 69.12, 39.68, -1.78],
                     strides=[2],
                     sizes=[[3.9, 1.6, 1.56]],
                     rotations=[0, 1.57],
                     custom_values=[],
                     reshape_out=False),
                 assigner_per_size: bool = False,
                 assign_per_class: bool = False,
                 diff_rad_by_sin: bool = True,
                 dir_offset: float = -np.pi / 2,
                 dir_limit_offset: int = 0,
                 bbox_coder: ConfigType = dict(type='DeltaXYZWLHRBBoxCoder'),
                 loss_cls: ConfigType = dict(
                     type='mmdet.CrossEntropyLoss',
                     use_sigmoid=True,
                     loss_weight=1.0),
                 loss_bbox: ConfigType = dict(
                     type='mmdet.SmoothL1Loss',
                     beta=1.0 / 9.0,
                     loss_weight=2.0),
                 loss_dir: ConfigType = dict(
                     type='mmdet.CrossEntropyLoss', loss_weight=0.2),
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 init_cfg: OptConfigType = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.feat_channels = feat_channels
        self.diff_rad_by_sin = diff_rad_by_sin
        self.use_direction_classifier = use_direction_classifier
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.assigner_per_size = assigner_per_size
        self.assign_per_class = assign_per_class
        self.dir_offset = dir_offset
        self.dir_limit_offset = dir_limit_offset
        warnings.warn(
            'dir_offset and dir_limit_offset will be depressed and be '
            'incorporated into box coder in the future')
        # NOTE(review): the original assigned `fp16_enabled` twice; the
        # duplicate assignment has been removed.
        self.fp16_enabled = False

        # build anchor generator
        self.prior_generator = TASK_UTILS.build(anchor_generator)
        # In 3D detection, the anchor stride is connected with anchor size
        self.num_anchors = self.prior_generator.num_base_anchors
        # build box coder
        self.bbox_coder = TASK_UTILS.build(bbox_coder)
        self.box_code_size = self.bbox_coder.code_size

        # build loss function
        self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
        # Focal-style losses handle the pos/neg imbalance themselves, so no
        # explicit sampling is needed for them.
        self.sampling = loss_cls['type'] not in [
            'mmdet.FocalLoss', 'mmdet.GHMC'
        ]
        if not self.use_sigmoid_cls:
            # Softmax classification needs an extra background class.
            self.num_classes += 1
        self.loss_cls = MODELS.build(loss_cls)
        self.loss_bbox = MODELS.build(loss_bbox)
        self.loss_dir = MODELS.build(loss_dir)

        self._init_layers()
        self._init_assigner_sampler()

        if init_cfg is None:
            self.init_cfg = dict(
                type='Normal',
                layer='Conv2d',
                std=0.01,
                override=dict(
                    type='Normal', name='conv_cls', std=0.01, bias_prob=0.01))
    def _init_assigner_sampler(self):
        """Initialize the target assigner and sampler of the head."""
        # Nothing to build in test-only mode.
        if self.train_cfg is None:
            return

        # Sampling-based losses need a real sampler; otherwise a pseudo
        # sampler simply forwards all assigned anchors.
        if self.sampling:
            self.bbox_sampler = TASK_UTILS.build(self.train_cfg.sampler)
        else:
            self.bbox_sampler = PseudoSampler()

        assigner_cfg = self.train_cfg.assigner
        if isinstance(assigner_cfg, dict):
            self.bbox_assigner = TASK_UTILS.build(assigner_cfg)
        elif isinstance(assigner_cfg, list):
            # One assigner per class / anchor size.
            self.bbox_assigner = [
                TASK_UTILS.build(cfg) for cfg in assigner_cfg
            ]
    def _init_layers(self):
145
        """Initialize neural network layers of the head."""
zhangwenwei's avatar
zhangwenwei committed
146
147
148
149
150
151
152
153
        self.cls_out_channels = self.num_anchors * self.num_classes
        self.conv_cls = nn.Conv2d(self.feat_channels, self.cls_out_channels, 1)
        self.conv_reg = nn.Conv2d(self.feat_channels,
                                  self.num_anchors * self.box_code_size, 1)
        if self.use_direction_classifier:
            self.conv_dir_cls = nn.Conv2d(self.feat_channels,
                                          self.num_anchors * 2, 1)

154
    def forward_single(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
wuyuefeng's avatar
wuyuefeng committed
155
156
157
        """Forward function on a single-scale feature map.

        Args:
158
            x (Tensor): Features of a single scale level.
wuyuefeng's avatar
wuyuefeng committed
159
160

        Returns:
161
162
163
164
165
166
167
168
            tuple:
                cls_score (Tensor): Cls scores for a single scale level
                    the channels number is num_base_priors * num_classes.
                bbox_pred (Tensor): Box energies / deltas for a single scale
                    level, the channels number is num_base_priors * C.
                dir_cls_pred (Tensor | None): Direction classification
                    prediction for a single scale level, the channels
                    number is num_base_priors * 2.
wuyuefeng's avatar
wuyuefeng committed
169
        """
zhangwenwei's avatar
zhangwenwei committed
170
171
        cls_score = self.conv_cls(x)
        bbox_pred = self.conv_reg(x)
172
        dir_cls_pred = None
zhangwenwei's avatar
zhangwenwei committed
173
        if self.use_direction_classifier:
174
175
            dir_cls_pred = self.conv_dir_cls(x)
        return cls_score, bbox_pred, dir_cls_pred
zhangwenwei's avatar
zhangwenwei committed
176

177
    def forward(self, x: Tuple[Tensor]) -> Tuple[List[Tensor]]:
        """Forward pass over every scale level.

        Args:
            x (tuple[Tensor]): Features from the upstream network,
                each is a 4D-tensor.

        Returns:
            tuple: A tuple of classification scores, bbox and direction
                classification prediction.

                - cls_scores (list[Tensor]): Classification scores for all
                    scale levels, each is a 4D-tensor, the channels number
                    is num_base_priors * num_classes.
                - bbox_preds (list[Tensor]): Box energies / deltas for all
                    scale levels, each is a 4D-tensor, the channels number
                    is num_base_priors * C.
                - dir_cls_preds (list[Tensor|None]): Direction classification
                    predictions for all scale levels, each is a 4D-tensor,
                    the channels number is num_base_priors * 2.
        """
        # Apply the single-level head to each feature level and regroup the
        # per-level tuples into per-output lists.
        return multi_apply(self.forward_single, x)

    # TODO: Support augmentation test
    def aug_test(self,
                 aug_batch_feats,
                 aug_batch_input_metas,
                 rescale=False,
                 **kwargs):
        """Run the head on each augmented view and merge the boxes.

        NOTE(review): only a single sample per augmentation is supported.
        """
        aug_bboxes = []
        for feats, input_meta in zip(aug_batch_feats, aug_batch_input_metas):
            outs = self.forward(feats)
            results = self.get_results(*outs, [input_meta], rescale=rescale)
            aug_bboxes.append(
                dict(
                    bboxes_3d=results[0].bboxes_3d,
                    scores_3d=results[0].scores_3d,
                    labels_3d=results[0].labels_3d))
        # After merging, bboxes will be rescaled to the original image size.
        merged_bboxes = merge_aug_bboxes_3d(aug_bboxes, aug_batch_input_metas,
                                            self.test_cfg)
        return [merged_bboxes]
    def get_anchors(self,
                    featmap_sizes: List[tuple],
                    input_metas: List[dict],
                    device: str = 'cuda') -> list:
zhangwenwei's avatar
zhangwenwei committed
225
        """Get anchors according to feature map sizes.
zhangwenwei's avatar
zhangwenwei committed
226

zhangwenwei's avatar
zhangwenwei committed
227
228
229
        Args:
            featmap_sizes (list[tuple]): Multi-level feature map sizes.
            input_metas (list[dict]): contain pcd and img's meta info.
wangtai's avatar
wangtai committed
230
            device (str): device of current module.
zhangwenwei's avatar
zhangwenwei committed
231

zhangwenwei's avatar
zhangwenwei committed
232
        Returns:
233
            list[list[torch.Tensor]]: Anchors of each image, valid flags
wangtai's avatar
wangtai committed
234
                of each image.
zhangwenwei's avatar
zhangwenwei committed
235
236
237
238
        """
        num_imgs = len(input_metas)
        # since feature map sizes of all images are the same, we only compute
        # anchors for one time
239
        multi_level_anchors = self.prior_generator.grid_anchors(
240
            featmap_sizes, device=device)
zhangwenwei's avatar
zhangwenwei committed
241
242
243
        anchor_list = [multi_level_anchors for _ in range(num_imgs)]
        return anchor_list

244
245
246
247
248
    def _loss_by_feat_single(self, cls_score: Tensor, bbox_pred: Tensor,
                             dir_cls_pred: Tensor, labels: Tensor,
                             label_weights: Tensor, bbox_targets: Tensor,
                             bbox_weights: Tensor, dir_targets: Tensor,
                             dir_weights: Tensor, num_total_samples: int):
wuyuefeng's avatar
wuyuefeng committed
249
250
251
        """Calculate loss of Single-level results.

        Args:
252
253
254
            cls_score (Tensor): Class score in single-level.
            bbox_pred (Tensor): Bbox prediction in single-level.
            dir_cls_pred (Tensor): Predictions of direction class
wuyuefeng's avatar
wuyuefeng committed
255
                in single-level.
256
257
258
259
260
261
            labels (Tensor): Labels of class.
            label_weights (Tensor): Weights of class loss.
            bbox_targets (Tensor): Targets of bbox predictions.
            bbox_weights (Tensor): Weights of bbox loss.
            dir_targets (Tensor): Targets of direction predictions.
            dir_weights (Tensor): Weights of direction loss.
wuyuefeng's avatar
wuyuefeng committed
262
263
264
            num_total_samples (int): The number of valid samples.

        Returns:
265
            tuple[torch.Tensor]: Losses of class, bbox
liyinhao's avatar
liyinhao committed
266
                and direction, respectively.
wuyuefeng's avatar
wuyuefeng committed
267
        """
zhangwenwei's avatar
zhangwenwei committed
268
269
270
271
272
273
        # classification loss
        if num_total_samples is None:
            num_total_samples = int(cls_score.shape[0])
        labels = labels.reshape(-1)
        label_weights = label_weights.reshape(-1)
        cls_score = cls_score.permute(0, 2, 3, 1).reshape(-1, self.num_classes)
274
        assert labels.max().item() <= self.num_classes
zhangwenwei's avatar
zhangwenwei committed
275
276
277
278
        loss_cls = self.loss_cls(
            cls_score, labels, label_weights, avg_factor=num_total_samples)

        # regression loss
279
280
        bbox_pred = bbox_pred.permute(0, 2, 3,
                                      1).reshape(-1, self.box_code_size)
zhangwenwei's avatar
zhangwenwei committed
281
282
283
        bbox_targets = bbox_targets.reshape(-1, self.box_code_size)
        bbox_weights = bbox_weights.reshape(-1, self.box_code_size)

284
285
        bg_class_ind = self.num_classes
        pos_inds = ((labels >= 0)
Wenhao Wu's avatar
Wenhao Wu committed
286
287
                    & (labels < bg_class_ind)).nonzero(
                        as_tuple=False).reshape(-1)
288
289
290
291
292
293
294
        num_pos = len(pos_inds)

        pos_bbox_pred = bbox_pred[pos_inds]
        pos_bbox_targets = bbox_targets[pos_inds]
        pos_bbox_weights = bbox_weights[pos_inds]

        # dir loss
zhangwenwei's avatar
zhangwenwei committed
295
        if self.use_direction_classifier:
296
            dir_cls_pred = dir_cls_pred.permute(0, 2, 3, 1).reshape(-1, 2)
zhangwenwei's avatar
zhangwenwei committed
297
298
            dir_targets = dir_targets.reshape(-1)
            dir_weights = dir_weights.reshape(-1)
299
            pos_dir_cls_pred = dir_cls_pred[pos_inds]
300
301
302
303
304
305
            pos_dir_targets = dir_targets[pos_inds]
            pos_dir_weights = dir_weights[pos_inds]

        if num_pos > 0:
            code_weight = self.train_cfg.get('code_weight', None)
            if code_weight:
306
                pos_bbox_weights = pos_bbox_weights * bbox_weights.new_tensor(
307
308
309
310
311
312
313
314
                    code_weight)
            if self.diff_rad_by_sin:
                pos_bbox_pred, pos_bbox_targets = self.add_sin_difference(
                    pos_bbox_pred, pos_bbox_targets)
            loss_bbox = self.loss_bbox(
                pos_bbox_pred,
                pos_bbox_targets,
                pos_bbox_weights,
zhangwenwei's avatar
zhangwenwei committed
315
316
                avg_factor=num_total_samples)

317
318
319
320
            # direction classification loss
            loss_dir = None
            if self.use_direction_classifier:
                loss_dir = self.loss_dir(
321
                    pos_dir_cls_pred,
322
323
324
325
326
327
                    pos_dir_targets,
                    pos_dir_weights,
                    avg_factor=num_total_samples)
        else:
            loss_bbox = pos_bbox_pred.sum()
            if self.use_direction_classifier:
328
                loss_dir = pos_dir_cls_pred.sum()
329

zhangwenwei's avatar
zhangwenwei committed
330
331
332
        return loss_cls, loss_bbox, loss_dir

    @staticmethod
333
    def add_sin_difference(boxes1: Tensor, boxes2: Tensor) -> tuple:
zhangwenwei's avatar
zhangwenwei committed
334
        """Convert the rotation difference to difference in sine function.
zhangwenwei's avatar
zhangwenwei committed
335
336

        Args:
zhangwenwei's avatar
zhangwenwei committed
337
338
339
340
            boxes1 (torch.Tensor): Original Boxes in shape (NxC), where C>=7
                and the 7th dimension is rotation dimension.
            boxes2 (torch.Tensor): Target boxes in shape (NxC), where C>=7 and
                the 7th dimension is rotation dimension.
zhangwenwei's avatar
zhangwenwei committed
341
342

        Returns:
343
            tuple[torch.Tensor]: ``boxes1`` and ``boxes2`` whose 7th
zhangwenwei's avatar
zhangwenwei committed
344
                dimensions are changed.
zhangwenwei's avatar
zhangwenwei committed
345
346
347
348
349
350
351
352
353
        """
        rad_pred_encoding = torch.sin(boxes1[..., 6:7]) * torch.cos(
            boxes2[..., 6:7])
        rad_tg_encoding = torch.cos(boxes1[..., 6:7]) * torch.sin(boxes2[...,
                                                                         6:7])
        boxes1 = torch.cat(
            [boxes1[..., :6], rad_pred_encoding, boxes1[..., 7:]], dim=-1)
        boxes2 = torch.cat([boxes2[..., :6], rad_tg_encoding, boxes2[..., 7:]],
                           dim=-1)
zhangwenwei's avatar
zhangwenwei committed
354
355
        return boxes1, boxes2

356
357
358
359
360
361
362
363
364
365
    def loss_by_feat(
            self,
            cls_scores: List[Tensor],
            bbox_preds: List[Tensor],
            dir_cls_preds: List[Tensor],
            batch_gt_instances_3d: InstanceList,
            batch_input_metas: List[dict],
            batch_gt_instances_ignore: OptInstanceList = None) -> dict:
        """Calculate the loss based on the features extracted by the detection
        head.

        Args:
            cls_scores (list[torch.Tensor]): Multi-level class scores.
            bbox_preds (list[torch.Tensor]): Multi-level bbox predictions.
            dir_cls_preds (list[torch.Tensor]): Multi-level direction
                class predictions.
            batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
                gt_instances. It usually includes ``bboxes_3d``
                and ``labels_3d`` attributes.
            batch_input_metas (list[dict]): Contain pcd and img's meta info.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.

        Returns:
            dict[str, list[torch.Tensor]]: Classification, bbox, and
                direction losses of each level.

                - loss_cls (list[torch.Tensor]): Classification losses.
                - loss_bbox (list[torch.Tensor]): Box regression losses.
                - loss_dir (list[torch.Tensor]): Direction classification
                    losses.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == self.prior_generator.num_levels
        device = cls_scores[0].device
        anchor_list = self.get_anchors(
            featmap_sizes, batch_input_metas, device=device)
        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
        cls_reg_targets = self.anchor_target_3d(
            anchor_list,
            batch_gt_instances_3d,
            batch_input_metas,
            batch_gt_instances_ignore=batch_gt_instances_ignore,
            num_classes=self.num_classes,
            label_channels=label_channels,
            sampling=self.sampling)

        if cls_reg_targets is None:
            return None
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         dir_targets_list, dir_weights_list, num_total_pos,
         num_total_neg) = cls_reg_targets
        # With sampling-based losses the averaging factor counts negatives
        # too; focal-style losses average over positives only.
        num_total_samples = (
            num_total_pos + num_total_neg if self.sampling else num_total_pos)

        losses_cls, losses_bbox, losses_dir = multi_apply(
            self._loss_by_feat_single,
            cls_scores,
            bbox_preds,
            dir_cls_preds,
            labels_list,
            label_weights_list,
            bbox_targets_list,
            bbox_weights_list,
            dir_targets_list,
            dir_weights_list,
            num_total_samples=num_total_samples)
        return dict(
            loss_cls=losses_cls, loss_bbox=losses_bbox, loss_dir=losses_dir)