# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import List, Tuple

import numpy as np
import torch
from mmdet.models.utils import multi_apply
from mmdet.utils.memory import cast_tensor_type
from mmengine.runner import amp
from torch import Tensor
from torch import nn as nn

from mmdet3d.models.task_modules import PseudoSampler
from mmdet3d.models.test_time_augs import merge_aug_bboxes_3d
from mmdet3d.registry import MODELS, TASK_UTILS
from mmdet3d.utils.typing_utils import (ConfigType, InstanceList,
                                        OptConfigType, OptInstanceList)
from .base_3d_dense_head import Base3DDenseHead
from .train_mixins import AnchorTrainMixin


@MODELS.register_module()
class Anchor3DHead(Base3DDenseHead, AnchorTrainMixin):
    """Anchor-based head for SECOND/PointPillars/MVXNet/PartA2.

    Args:
        num_classes (int): Number of classes.
        in_channels (int): Number of channels in the input feature map.
        feat_channels (int): Number of channels of the feature map.
        use_direction_classifier (bool): Whether to add a direction classifier.
        anchor_generator (dict): Config dict of anchor generator.
        assigner_per_size (bool): Whether to do assignment for each separate
            anchor size.
        assign_per_class (bool): Whether to do assignment for each class.
        diff_rad_by_sin (bool): Whether to change the difference into sin
            difference for box regression loss.
        dir_offset (float | int): The offset of BEV rotation angles.
            (TODO: may be moved into box coder)
        dir_limit_offset (float | int): The limited range of BEV
            rotation angles. (TODO: may be moved into box coder)
        bbox_coder (dict): Config dict of box coders.
        loss_cls (dict): Config of classification loss.
        loss_bbox (dict): Config of localization loss.
        loss_dir (dict): Config of direction classifier loss.
        train_cfg (dict): Train configs.
        test_cfg (dict): Test configs.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 num_classes: int,
                 in_channels: int,
                 feat_channels: int = 256,
                 use_direction_classifier: bool = True,
                 anchor_generator: ConfigType = dict(
                     type='Anchor3DRangeGenerator',
                     range=[0, -39.68, -1.78, 69.12, 39.68, -1.78],
                     strides=[2],
                     sizes=[[3.9, 1.6, 1.56]],
                     rotations=[0, 1.57],
                     custom_values=[],
                     reshape_out=False),
                 assigner_per_size: bool = False,
                 assign_per_class: bool = False,
                 diff_rad_by_sin: bool = True,
                 dir_offset: float = -np.pi / 2,
                 dir_limit_offset: int = 0,
                 bbox_coder: ConfigType = dict(type='DeltaXYZWLHRBBoxCoder'),
                 loss_cls: ConfigType = dict(
                     type='mmdet.CrossEntropyLoss',
                     use_sigmoid=True,
                     loss_weight=1.0),
                 loss_bbox: ConfigType = dict(
                     type='mmdet.SmoothL1Loss',
                     beta=1.0 / 9.0,
                     loss_weight=2.0),
                 loss_dir: ConfigType = dict(
                     type='mmdet.CrossEntropyLoss', loss_weight=0.2),
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 init_cfg: OptConfigType = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.feat_channels = feat_channels
        self.diff_rad_by_sin = diff_rad_by_sin
        self.use_direction_classifier = use_direction_classifier
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.assigner_per_size = assigner_per_size
        self.assign_per_class = assign_per_class
        self.dir_offset = dir_offset
        self.dir_limit_offset = dir_limit_offset
        warnings.warn(
            'dir_offset and dir_limit_offset will be deprecated and '
            'incorporated into the box coder in the future')

        # build anchor generator
        self.prior_generator = TASK_UTILS.build(anchor_generator)
        # In 3D detection, the anchor stride is connected with anchor size
        self.num_anchors = self.prior_generator.num_base_anchors
        # build box coder
        self.bbox_coder = TASK_UTILS.build(bbox_coder)
        self.box_code_size = self.bbox_coder.code_size

        # build loss function
        self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
        self.sampling = loss_cls['type'] not in [
            'mmdet.FocalLoss', 'mmdet.GHMC'
        ]
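        # FocalLoss and GHMC down-weight easy negatives themselves, so heads
        # configured with them skip explicit positive/negative sampling.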
        if not self.use_sigmoid_cls:
            self.num_classes += 1
        self.loss_cls = MODELS.build(loss_cls)
        self.loss_bbox = MODELS.build(loss_bbox)
        self.loss_dir = MODELS.build(loss_dir)

        self._init_layers()
        self._init_assigner_sampler()

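        # bias_prob=0.01 initializes the conv_cls bias so that the initial
        # foreground probability is about 0.01 (the RetinaNet-style prior),
        # which keeps the classification loss stable early in training.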
        if init_cfg is None:
            self.init_cfg = dict(
                type='Normal',
                layer='Conv2d',
                std=0.01,
                override=dict(
                    type='Normal', name='conv_cls', std=0.01, bias_prob=0.01))

    def _init_assigner_sampler(self):
        """Initialize the target assigner and sampler of the head."""
        if self.train_cfg is None:
            return

        if self.sampling:
            self.bbox_sampler = TASK_UTILS.build(self.train_cfg.sampler)
        else:
            self.bbox_sampler = PseudoSampler()
        if isinstance(self.train_cfg.assigner, dict):
            self.bbox_assigner = TASK_UTILS.build(self.train_cfg.assigner)
        elif isinstance(self.train_cfg.assigner, list):
            self.bbox_assigner = [
                TASK_UTILS.build(res) for res in self.train_cfg.assigner
            ]

    def _init_layers(self):
        """Initialize neural network layers of the head."""
        self.cls_out_channels = self.num_anchors * self.num_classes
        self.conv_cls = nn.Conv2d(self.feat_channels, self.cls_out_channels, 1)
        self.conv_reg = nn.Conv2d(self.feat_channels,
                                  self.num_anchors * self.box_code_size, 1)
        if self.use_direction_classifier:
            self.conv_dir_cls = nn.Conv2d(self.feat_channels,
                                          self.num_anchors * 2, 1)
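        # With the default anchor generator above (1 size x 2 rotations = 2
        # base anchors per location) and a 7-dim box code, conv_cls emits
        # 2 * num_classes channels, conv_reg 2 * 7 = 14 channels, and
        # conv_dir_cls 2 * 2 = 4 channels per location.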

    def forward_single(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """Forward function on a single-scale feature map.

        Args:
            x (Tensor): Features of a single scale level.

        Returns:
            tuple:
                cls_score (Tensor): Cls scores for a single scale level
                    the channels number is num_base_priors * num_classes.
                bbox_pred (Tensor): Box energies / deltas for a single scale
                    level, the channels number is num_base_priors * C.
                dir_cls_pred (Tensor | None): Direction classification
                    prediction for a single scale level, the channels
                    number is num_base_priors * 2.
        """
        cls_score = self.conv_cls(x)
        bbox_pred = self.conv_reg(x)
        dir_cls_pred = None
        if self.use_direction_classifier:
            dir_cls_pred = self.conv_dir_cls(x)
        return cls_score, bbox_pred, dir_cls_pred

    def forward(self, x: Tuple[Tensor]) -> Tuple[List[Tensor]]:
        """Forward pass.

        Args:
            x (tuple[Tensor]): Features from the upstream network,
                each is a 4D-tensor.

        Returns:
            tuple: A tuple of classification scores, bbox and direction
                classification prediction.

                - cls_scores (list[Tensor]): Classification scores for all
                    scale levels, each is a 4D-tensor, the channels number
                    is num_base_priors * num_classes.
                - bbox_preds (list[Tensor]): Box energies / deltas for all
                    scale levels, each is a 4D-tensor, the channels number
                    is num_base_priors * C.
                - dir_cls_preds (list[Tensor|None]): Direction classification
                    predictions for all scale levels, each is a 4D-tensor,
                    the channels number is num_base_priors * 2.
        """
        return multi_apply(self.forward_single, x)
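
    # A minimal usage sketch (illustrative only; the channel and feature-map
    # sizes below are assumptions, not fixed by this file):
    #
    #   head = MODELS.build(
    #       dict(type='Anchor3DHead', num_classes=1, in_channels=256))
    #   cls_scores, bbox_preds, dir_cls_preds = head(
    #       [torch.rand(4, 256, 200, 176)])
    #   # each output is a list with one 4D tensor per feature level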

    # TODO: Support augmentation test
    def aug_test(self,
                 aug_batch_feats,
                 aug_batch_input_metas,
                 rescale=False,
                 **kwargs):
        aug_bboxes = []
        # only support aug_test for one sample
        for x, input_meta in zip(aug_batch_feats, aug_batch_input_metas):
            outs = self.forward(x)
            bbox_list = self.get_results(*outs, [input_meta], rescale=rescale)
            bbox_dict = dict(
                bboxes_3d=bbox_list[0].bboxes_3d,
                scores_3d=bbox_list[0].scores_3d,
                labels_3d=bbox_list[0].labels_3d)
            aug_bboxes.append(bbox_dict)
        # after merging, bboxes will be rescaled to the original image size
        merged_bboxes = merge_aug_bboxes_3d(aug_bboxes, aug_batch_input_metas,
                                            self.test_cfg)
        return [merged_bboxes]

    def get_anchors(self,
                    featmap_sizes: List[tuple],
                    input_metas: List[dict],
                    device: str = 'cuda') -> list:
        """Get anchors according to feature map sizes.

        Args:
            featmap_sizes (list[tuple]): Multi-level feature map sizes.
            input_metas (list[dict]): Contains pcd and img's meta info.
            device (str): Device of current module.

        Returns:
            list[list[torch.Tensor]]: Anchors of each image, valid flags
                of each image.
        """
        num_imgs = len(input_metas)
        # since feature map sizes of all images are the same, we only compute
        # anchors for one time
        multi_level_anchors = self.prior_generator.grid_anchors(
            featmap_sizes, device=device)
        anchor_list = [multi_level_anchors for _ in range(num_imgs)]
        return anchor_list
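
    # Note: all images share one set of multi-level anchors; per level there
    # are about H * W * num_base_anchors anchors with box_code_size dims
    # (the exact tensor shape depends on the generator's reshape_out setting).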

    def _loss_by_feat_single(self, cls_score: Tensor, bbox_pred: Tensor,
                             dir_cls_pred: Tensor, labels: Tensor,
                             label_weights: Tensor, bbox_targets: Tensor,
                             bbox_weights: Tensor, dir_targets: Tensor,
                             dir_weights: Tensor, num_total_samples: int):
        """Calculate the loss of single-level results.

        Args:
            cls_score (Tensor): Class score in single-level.
            bbox_pred (Tensor): Bbox prediction in single-level.
            dir_cls_pred (Tensor): Predictions of direction class
                in single-level.
            labels (Tensor): Labels of class.
            label_weights (Tensor): Weights of class loss.
            bbox_targets (Tensor): Targets of bbox predictions.
            bbox_weights (Tensor): Weights of bbox loss.
            dir_targets (Tensor): Targets of direction predictions.
            dir_weights (Tensor): Weights of direction loss.
            num_total_samples (int): The number of valid samples.

        Returns:
            tuple[torch.Tensor]: Losses of class, bbox
                and direction, respectively.
        """
        # classification loss
        if num_total_samples is None:
            num_total_samples = int(cls_score.shape[0])
        labels = labels.reshape(-1)
        label_weights = label_weights.reshape(-1)
        cls_score = cls_score.permute(0, 2, 3, 1).reshape(-1, self.num_classes)
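        # (N, A * C, H, W) -> (N, H, W, A * C) -> (N * H * W * A, C), matching
        # the flattened per-anchor layout of labels and label_weights.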
        assert labels.max().item() <= self.num_classes
        loss_cls = self.loss_cls(
            cls_score, labels, label_weights, avg_factor=num_total_samples)

        # regression loss
        bbox_pred = bbox_pred.permute(0, 2, 3,
                                      1).reshape(-1, self.box_code_size)
        bbox_targets = bbox_targets.reshape(-1, self.box_code_size)
        bbox_weights = bbox_weights.reshape(-1, self.box_code_size)

        bg_class_ind = self.num_classes
        pos_inds = ((labels >= 0)
                    & (labels < bg_class_ind)).nonzero(
                        as_tuple=False).reshape(-1)
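        # Labels in [0, num_classes) mark positive anchors; the target
        # assigner reserves index num_classes for background.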
        num_pos = len(pos_inds)

        pos_bbox_pred = bbox_pred[pos_inds]
        pos_bbox_targets = bbox_targets[pos_inds]
        pos_bbox_weights = bbox_weights[pos_inds]

        # dir loss
        if self.use_direction_classifier:
            dir_cls_pred = dir_cls_pred.permute(0, 2, 3, 1).reshape(-1, 2)
            dir_targets = dir_targets.reshape(-1)
            dir_weights = dir_weights.reshape(-1)
            pos_dir_cls_pred = dir_cls_pred[pos_inds]
            pos_dir_targets = dir_targets[pos_inds]
            pos_dir_weights = dir_weights[pos_inds]

        if num_pos > 0:
            code_weight = self.train_cfg.get('code_weight', None)
            if code_weight:
                pos_bbox_weights = pos_bbox_weights * bbox_weights.new_tensor(
                    code_weight)
            if self.diff_rad_by_sin:
                pos_bbox_pred, pos_bbox_targets = self.add_sin_difference(
                    pos_bbox_pred, pos_bbox_targets)
            loss_bbox = self.loss_bbox(
                pos_bbox_pred,
                pos_bbox_targets,
                pos_bbox_weights,
                avg_factor=num_total_samples)

            # direction classification loss
            loss_dir = None
            if self.use_direction_classifier:
                loss_dir = self.loss_dir(
                    pos_dir_cls_pred,
                    pos_dir_targets,
                    pos_dir_weights,
                    avg_factor=num_total_samples)
        else:
            loss_bbox = pos_bbox_pred.sum()
            if self.use_direction_classifier:
                loss_dir = pos_dir_cls_pred.sum()

        return loss_cls, loss_bbox, loss_dir

    @staticmethod
    def add_sin_difference(boxes1: Tensor, boxes2: Tensor) -> tuple:
        """Convert the rotation difference to difference in sine function.

        Args:
            boxes1 (torch.Tensor): Original Boxes in shape (NxC), where C>=7
                and the 7th dimension is rotation dimension.
            boxes2 (torch.Tensor): Target boxes in shape (NxC), where C>=7 and
                the 7th dimension is rotation dimension.

        Returns:
            tuple[torch.Tensor]: ``boxes1`` and ``boxes2`` whose 7th
                dimensions are changed.
        """
        rad_pred_encoding = torch.sin(boxes1[..., 6:7]) * torch.cos(
            boxes2[..., 6:7])
        rad_tg_encoding = torch.cos(boxes1[..., 6:7]) * torch.sin(boxes2[...,
                                                                         6:7])
        boxes1 = torch.cat(
            [boxes1[..., :6], rad_pred_encoding, boxes1[..., 7:]], dim=-1)
        boxes2 = torch.cat([boxes2[..., :6], rad_tg_encoding, boxes2[..., 7:]],
                           dim=-1)
        return boxes1, boxes2
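
    # The encoding above uses sin(a - b) = sin(a)cos(b) - cos(a)sin(b): the
    # elementwise difference of the two returned rotation channels equals
    # sin(pred - target), which is smooth across the +/-pi boundary.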

    def loss_by_feat(
            self,
            cls_scores: List[Tensor],
            bbox_preds: List[Tensor],
            dir_cls_preds: List[Tensor],
            batch_gt_instances_3d: InstanceList,
            batch_input_metas: List[dict],
            batch_gt_instances_ignore: OptInstanceList = None) -> dict:
        """Calculate the loss based on the features extracted by the detection
        head.

        Args:
            cls_scores (list[torch.Tensor]): Multi-level class scores.
            bbox_preds (list[torch.Tensor]): Multi-level bbox predictions.
            dir_cls_preds (list[torch.Tensor]): Multi-level direction
                class predictions.
            batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
                gt_instances. It usually includes ``bboxes_3d``
                and ``labels_3d`` attributes.
            batch_input_metas (list[dict]): Contain pcd and img's meta info.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.

        Returns:
            dict[str, list[torch.Tensor]]: Classification, bbox, and
                direction losses of each level.

                - loss_cls (list[torch.Tensor]): Classification losses.
                - loss_bbox (list[torch.Tensor]): Box regression losses.
                - loss_dir (list[torch.Tensor]): Direction classification
                    losses.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == self.prior_generator.num_levels
        device = cls_scores[0].device
        anchor_list = self.get_anchors(
            featmap_sizes, batch_input_metas, device=device)
        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
        cls_reg_targets = self.anchor_target_3d(
            anchor_list,
            batch_gt_instances_3d,
            batch_input_metas,
            batch_gt_instances_ignore=batch_gt_instances_ignore,
            num_classes=self.num_classes,
            label_channels=label_channels,
            sampling=self.sampling)

        if cls_reg_targets is None:
            return None
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         dir_targets_list, dir_weights_list, num_total_pos,
         num_total_neg) = cls_reg_targets
        num_total_samples = (
            num_total_pos + num_total_neg if self.sampling else num_total_pos)

        # num_total_samples = None
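        # Losses are computed in float32 even when AMP is enabled: the
        # (possibly half-precision) predictions are cast up first so the loss
        # reductions neither overflow nor lose precision.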
        with amp.autocast(enabled=False):
            losses_cls, losses_bbox, losses_dir = multi_apply(
                self._loss_by_feat_single,
                cast_tensor_type(cls_scores, dst_type=torch.float32),
                cast_tensor_type(bbox_preds, dst_type=torch.float32),
                cast_tensor_type(dir_cls_preds, dst_type=torch.float32),
                labels_list,
                label_weights_list,
                bbox_targets_list,
                bbox_weights_list,
                dir_targets_list,
                dir_weights_list,
                num_total_samples=num_total_samples)
        return dict(
            loss_cls=losses_cls, loss_bbox=losses_bbox, loss_dir=losses_dir)