smoke_mono3d_head.py 22.7 KB
Newer Older
1
# Copyright (c) OpenMMLab. All rights reserved.
ZCMax's avatar
ZCMax committed
2
from typing import List, Optional, Tuple
ZCMax's avatar
ZCMax committed
3

4
import torch
5
6
7
8
9
from mmdet.models.utils import (gaussian_radius, gen_gaussian_target,
                                multi_apply)
from mmdet.models.utils.gaussian_target import (get_local_maximum,
                                                get_topk_from_heatmap,
                                                transpose_and_gather_feat)
10
from mmengine.structures import InstanceData
ZCMax's avatar
ZCMax committed
11
from torch import Tensor
12
13
from torch.nn import functional as F

ZCMax's avatar
ZCMax committed
14
from mmdet3d.registry import MODELS, TASK_UTILS
zhangshilong's avatar
zhangshilong committed
15
16
from mmdet3d.utils import (ConfigType, InstanceList, OptConfigType,
                           OptInstanceList, OptMultiConfig)
17
18
19
from .anchor_free_mono3d_head import AnchorFreeMono3DHead


20
@MODELS.register_module()
class SMOKEMono3DHead(AnchorFreeMono3DHead):
    r"""Anchor-free head used in `SMOKE <https://arxiv.org/abs/2002.10111>`_

    .. code-block:: none

                /-----> 3*3 conv -----> 1*1 conv -----> cls
        feature
                \-----> 3*3 conv -----> 1*1 conv -----> reg

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        dim_channel (list[int]): indices of dimension offset preds in
            regression heatmap channels.
        ori_channel (list[int]): indices of orientation offset pred in
            regression heatmap channels.
        bbox_coder (:obj:`ConfigDict` or dict): Bbox coder for encoding
            and decoding boxes.
        loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
            Default: loss_cls=dict(type='mmdet.GaussionFocalLoss',
            loss_weight=1.0).
        loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss.
            Default: loss_bbox=dict(type='mmdet.L1Loss', loss_weight=0.1).
        loss_dir (:obj:`ConfigDict` or dict, Optional): Config of direction
            classification loss. In SMOKE, Default: None.
        loss_attr (:obj:`ConfigDict` or dict, Optional): Config of attribute
            classification loss. In SMOKE, Default: None.
        norm_cfg (:obj:`ConfigDict` or dict): Dictionary to construct and
            config norm layer.
            Default: norm_cfg=dict(type='GN', num_groups=32,
            requires_grad=True).
        init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
            dict]): Initialization config dict. Defaults to None.
    """  # noqa: E501

    def __init__(self,
                 num_classes: int,
                 in_channels: int,
                 dim_channel: List[int],
                 ori_channel: List[int],
                 bbox_coder: ConfigType,
                 # NOTE(review): 'GaussionFocalLoss' looks like a typo of
                 # mmdet's GaussianFocalLoss — confirm the registry name;
                 # configs typically override loss_cls so this default may
                 # never be exercised.
                 loss_cls: ConfigType = dict(
                     type='mmdet.GaussionFocalLoss', loss_weight=1.0),
                 loss_bbox: ConfigType = dict(
                     type='mmdet.L1Loss', loss_weight=0.1),
                 loss_dir: OptConfigType = None,
                 loss_attr: OptConfigType = None,
                 norm_cfg: OptConfigType = dict(
                     type='GN', num_groups=32, requires_grad=True),
                 init_cfg: OptMultiConfig = None,
                 **kwargs) -> None:
        super().__init__(
            num_classes,
            in_channels,
            loss_cls=loss_cls,
            loss_bbox=loss_bbox,
            loss_dir=loss_dir,
            loss_attr=loss_attr,
            norm_cfg=norm_cfg,
            init_cfg=init_cfg,
            **kwargs)
        # Channel indices into the regression heatmap used by
        # forward_single to post-process dimension / orientation preds.
        self.dim_channel = dim_channel
        self.ori_channel = ori_channel
        self.bbox_coder = TASK_UTILS.build(bbox_coder)
83

ZCMax's avatar
ZCMax committed
84
    def forward(self, x: Tuple[Tensor]) -> Tuple[List[Tensor]]:
        """Forward features from the upstream network.

        Applies ``forward_single`` to every scale level independently.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple:
                cls_scores (list[Tensor]): Box scores for each scale level,
                    each is a 4D-tensor, the channel number is
                    num_points * num_classes.
                bbox_preds (list[Tensor]): Box energies / deltas for each
                    scale level, each is a 4D-tensor, the channel number is
                    num_points * bbox_code_size.
        """
        outs = multi_apply(self.forward_single, x)
        return outs
101

ZCMax's avatar
ZCMax committed
102
    def forward_single(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """Forward features of a single scale level.

        Args:
            x (Tensor): Input feature map.

        Returns:
            tuple: Scores for each class, bbox of input feature maps.
        """
        head_outs = super().forward_single(x)
        cls_score, bbox_pred = head_outs[0], head_outs[1]
        # Squash logits to (0, 1) and keep them strictly inside the open
        # interval for numerical stability of the downstream loss.
        cls_score = cls_score.sigmoid().clamp(min=1e-4, max=1 - 1e-4)
        # Dimension-offset channels (N, C, H, W): map to (-0.5, 0.5)
        # in place via a shifted sigmoid.
        dim_offsets = bbox_pred[:, self.dim_channel, ...]
        bbox_pred[:, self.dim_channel, ...] = dim_offsets.sigmoid() - 0.5
        # Orientation channels (N, C, H, W): L2-normalize along the
        # channel axis so the orientation vector has unit norm.
        ori_vector = bbox_pred[:, self.ori_channel, ...]
        bbox_pred[:, self.ori_channel, ...] = F.normalize(ori_vector)
        return cls_score, bbox_pred

ZCMax's avatar
ZCMax committed
123
124
125
126
127
    def predict_by_feat(self,
                        cls_scores: List[Tensor],
                        bbox_preds: List[Tensor],
                        batch_img_metas: Optional[List[dict]] = None,
                        rescale: bool = None) -> InstanceList:
        """Generate bboxes from bbox head predictions.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level.
            bbox_preds (list[Tensor]): Box regression for each scale.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            rescale (bool): If True, return boxes in original image space.

        Returns:
            list[:obj:`InstanceData`]: 3D Detection results of each image
            after the post process.
            Each item usually contains following keys.

            - scores_3d (Tensor): Classification scores, has a shape
                (num_instance, )
            - labels_3d (Tensor): Labels of bboxes, has a shape
                (num_instances, ).
            - bboxes_3d (Tensor): Contains a tensor with shape
                (num_instances, 7).
        """
        # SMOKE is single-scale: exactly one level of preds is expected.
        assert len(cls_scores) == len(bbox_preds) == 1
        ref = cls_scores[0]
        cam2imgs = torch.stack(
            [ref.new_tensor(meta['cam2img']) for meta in batch_img_metas])
        trans_mats = torch.stack(
            [ref.new_tensor(meta['trans_mat']) for meta in batch_img_metas])
        batch_bboxes, batch_scores, batch_topk_labels = self._decode_heatmap(
            cls_scores[0],
            bbox_preds[0],
            batch_img_metas,
            cam2imgs=cam2imgs,
            trans_mats=trans_mats,
            topk=100,
            kernel=3)

        result_list = []
        for img_id, img_meta in enumerate(batch_img_metas):
            scores = batch_scores[img_id]
            # Fixed confidence threshold used by SMOKE post-processing.
            keep = scores > 0.25
            scores = scores[keep]
            bboxes = batch_bboxes[img_id][keep]
            labels = batch_topk_labels[img_id][keep]

            # Wrap raw box tensors into the image's 3D box type with a
            # gravity-center origin.
            bboxes = img_meta['box_type_3d'](
                bboxes, box_dim=self.bbox_code_size, origin=(0.5, 0.5, 0.5))

            results = InstanceData()
            results.bboxes_3d = bboxes
            results.labels_3d = labels
            results.scores_3d = scores
            result_list.append(results)

        return result_list

ZCMax's avatar
ZCMax committed
195
196
197
198
199
200
201
202
    def _decode_heatmap(self,
                        cls_score: Tensor,
                        reg_pred: Tensor,
                        batch_img_metas: List[dict],
                        cam2imgs: Tensor,
                        trans_mats: Tensor,
                        topk: int = 100,
                        kernel: int = 3) -> Tuple[Tensor, Tensor, Tensor]:
        """Transform outputs into detections raw bbox predictions.

        Args:
            cls_score (Tensor): Center predict heatmap,
                shape (B, num_classes, H, W).
            reg_pred (Tensor): Box regression map.
                shape (B, channel, H , W).
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            cam2imgs (Tensor): Camera intrinsic matrixs.
                shape (B, 4, 4)
            trans_mats (Tensor): Transformation matrix from original image
                to feature map.
                shape: (batch, 3, 3)
            topk (int): Get top k center keypoints from heatmap. Default 100.
            kernel (int): Max pooling kernel for extract local maximum pixels.
               Default 3.

        Returns:
            tuple[torch.Tensor]: Decoded output of SMOKEHead, containing
               the following Tensors:

              - batch_bboxes (Tensor): Coords of each 3D box.
                    shape (B, k, 7)
              - batch_scores (Tensor): Scores of each 3D box.
                    shape (B, k)
              - batch_topk_labels (Tensor): Categories of each 3D box.
                    shape (B, k)
        """
        bs, _, feat_h, feat_w = cls_score.shape
        # Use the actual regression channel count instead of a hard-coded
        # 8 so heads configured with a different regression layout still
        # decode correctly.
        reg_channel = reg_pred.shape[1]

        # Keep only local-maximum pixels of the heatmap (max-pool NMS).
        center_heatmap_pred = get_local_maximum(cls_score, kernel=kernel)

        *batch_dets, topk_ys, topk_xs = get_topk_from_heatmap(
            center_heatmap_pred, k=topk)
        batch_scores, batch_index, batch_topk_labels = batch_dets

        # Gather the regression vector at each top-k center and flatten
        # (B, k, C) -> (B*k, C).
        regression = transpose_and_gather_feat(reg_pred, batch_index)
        regression = regression.view(-1, reg_channel)

        points = torch.cat([topk_xs.view(-1, 1).float(),
                            topk_ys.view(-1, 1).float()],
                           dim=1)
        locations, dimensions, orientations = self.bbox_coder.decode(
            regression, points, batch_topk_labels, cam2imgs, trans_mats)

        batch_bboxes = torch.cat((locations, dimensions, orientations), dim=1)
        batch_bboxes = batch_bboxes.view(bs, -1, self.bbox_code_size)
        return batch_bboxes, batch_scores, batch_topk_labels

ZCMax's avatar
ZCMax committed
254
255
256
257
    def get_predictions(self, labels_3d: Tensor, centers_2d: Tensor,
                        gt_locations: Tensor, gt_dimensions: Tensor,
                        gt_orientations: Tensor, indices: Tensor,
                        batch_img_metas: List[dict], pred_reg: Tensor) -> dict:
        """Prepare predictions for computing loss.

        Decodes the regression map at ground-truth center locations and
        builds three hybrid boxes, each replacing exactly one component
        (orientation / dimension / location) with the prediction while
        keeping the other two from ground truth.

        Args:
            labels_3d (Tensor): Labels of each 3D box.
                shape (B, max_objs, )
            centers_2d (Tensor): Coords of each projected 3D box
                center on image. shape (B * max_objs, 2)
            gt_locations (Tensor): Coords of each 3D box's location.
                shape (B * max_objs, 3)
            gt_dimensions (Tensor): Dimensions of each 3D box.
                shape (N, 3)
            gt_orientations (Tensor): Orientation(yaw) of each 3D box.
                shape (N, 1)
            indices (Tensor): Indices of the existence of the 3D box.
                shape (B * max_objs, )
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            pred_reg (Tensor): Box regression map.
                shape (B, channel, H , W).

        Returns:
            dict: the dict has components below:

            - bbox3d_yaws (:obj:`CameraInstance3DBoxes`):
                bbox calculated using pred orientations.
            - bbox3d_dims (:obj:`CameraInstance3DBoxes`):
                bbox calculated using pred dimensions.
            - bbox3d_locs (:obj:`CameraInstance3DBoxes`):
                bbox calculated using pred locations.
        """
        batch, channel = pred_reg.shape[0], pred_reg.shape[1]
        w = pred_reg.shape[3]
        cam2imgs = torch.stack([
            gt_locations.new_tensor(img_meta['cam2img'])
            for img_meta in batch_img_metas
        ])
        trans_mats = torch.stack([
            gt_locations.new_tensor(img_meta['trans_mat'])
            for img_meta in batch_img_metas
        ])
        # Flatten (x, y) centers into linear feature-map indices (row-major)
        # so the regression map can be gathered at GT centers.
        centers_2d_inds = centers_2d[:, 1] * w + centers_2d[:, 0]
        centers_2d_inds = centers_2d_inds.view(batch, -1)
        pred_regression = transpose_and_gather_feat(pred_reg, centers_2d_inds)
        pred_regression_pois = pred_regression.view(-1, channel)
        locations, dimensions, orientations = self.bbox_coder.decode(
            pred_regression_pois, centers_2d, labels_3d, cam2imgs, trans_mats,
            gt_locations)

        # Keep only the slots that actually contain a GT object.
        locations, dimensions, orientations = locations[indices], dimensions[
            indices], orientations[indices]

        # Shift y down by half the box height — presumably converting from
        # geometric center to the bottom-center origin convention used by
        # the bbox coder; confirm against self.bbox_coder.encode.
        locations[:, 1] += dimensions[:, 1] / 2

        gt_locations = gt_locations[indices]

        assert len(locations) == len(gt_locations)
        assert len(dimensions) == len(gt_dimensions)
        assert len(orientations) == len(gt_orientations)
        # Each encode call swaps in exactly one predicted component so its
        # loss can be supervised separately via the 3D box corners.
        bbox3d_yaws = self.bbox_coder.encode(gt_locations, gt_dimensions,
                                             orientations, batch_img_metas)
        bbox3d_dims = self.bbox_coder.encode(gt_locations, dimensions,
                                             gt_orientations, batch_img_metas)
        bbox3d_locs = self.bbox_coder.encode(locations, gt_dimensions,
                                             gt_orientations, batch_img_metas)

        pred_bboxes = dict(ori=bbox3d_yaws, dim=bbox3d_dims, loc=bbox3d_locs)

        return pred_bboxes

ZCMax's avatar
ZCMax committed
327
328
329
    def get_targets(self, batch_gt_instances_3d: InstanceList,
                    batch_gt_instances: InstanceList, feat_shape: Tuple[int],
                    batch_img_metas: List[dict]) -> Tuple[Tensor, int, dict]:
        """Get training targets for batch images.

        Args:
            batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
                gt_instance_3d.  It usually includes ``bboxes_3d``、
                ``labels_3d``、``depths``、``centers_2d`` and attributes.
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance.  It usually includes ``bboxes``、``labels``.
            feat_shape (tuple[int]): Feature map shape with value,
                shape (B, _, H, W).
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.

        Returns:
            tuple[Tensor, int, dict]: The Tensor value is the targets of
                center heatmap, the dict has components below:

              - gt_centers_2d (Tensor): Coords of each projected 3D box
                    center on image. shape (B * max_objs, 2)
              - gt_labels_3d (Tensor): Labels of each 3D box.
                    shape (B, max_objs, )
              - indices (Tensor): Indices of the existence of the 3D box.
                    shape (B * max_objs, )
              - reg_indices (Tensor): Mask of boxes from images without
                    affine augmentation (only those are regressed).
                    shape (N, )
              - gt_locs (Tensor): Coords of each 3D box's location.
                    shape (N, 3)
              - gt_dims (Tensor): Dimensions of each 3D box.
                    shape (N, 3)
              - gt_yaws (Tensor): Orientation(yaw) of each 3D box.
                    shape (N, 1)
              - gt_cors (Tensor): Coords of the corners of each 3D box.
                    shape (N, 8, 3)
        """

        # Per-image lists of GT annotations pulled out of InstanceData.
        gt_bboxes = [
            gt_instances.bboxes for gt_instances in batch_gt_instances
        ]
        gt_labels = [
            gt_instances.labels for gt_instances in batch_gt_instances
        ]
        gt_bboxes_3d = [
            gt_instances_3d.bboxes_3d
            for gt_instances_3d in batch_gt_instances_3d
        ]
        gt_labels_3d = [
            gt_instances_3d.labels_3d
            for gt_instances_3d in batch_gt_instances_3d
        ]
        centers_2d = [
            gt_instances_3d.centers_2d
            for gt_instances_3d in batch_gt_instances_3d
        ]
        img_shape = batch_img_metas[0]['pad_shape']

        # One bool per image: True iff the image was NOT affine-augmented.
        # Boxes from augmented images are excluded from regression loss.
        reg_mask = torch.stack([
            gt_bboxes[0].new_tensor(
                not img_meta['affine_aug'], dtype=torch.bool)
            for img_meta in batch_img_metas
        ])

        img_h, img_w = img_shape[:2]
        bs, _, feat_h, feat_w = feat_shape

        width_ratio = float(feat_w / img_w)  # 1/4
        height_ratio = float(feat_h / img_h)  # 1/4

        # Downstream projection uses a single scalar ratio, so the feature
        # stride must be identical in both axes.
        assert width_ratio == height_ratio

        center_heatmap_target = gt_bboxes[-1].new_zeros(
            [bs, self.num_classes, feat_h, feat_w])

        # Shallow copy of the per-image list; the tensors themselves are
        # not mutated (scaling below creates new tensors).
        gt_centers_2d = centers_2d.copy()

        for batch_id in range(bs):
            gt_bbox = gt_bboxes[batch_id]
            gt_label = gt_labels[batch_id]
            # project centers_2d from input image to feat map
            gt_center_2d = gt_centers_2d[batch_id] * width_ratio

            for j, center in enumerate(gt_center_2d):
                center_x_int, center_y_int = center.int()
                scale_box_h = (gt_bbox[j][3] - gt_bbox[j][1]) * height_ratio
                scale_box_w = (gt_bbox[j][2] - gt_bbox[j][0]) * width_ratio
                # Gaussian radius sized so any center within it still
                # yields >= 0.7 IoU with the GT 2D box.
                radius = gaussian_radius([scale_box_h, scale_box_w],
                                         min_overlap=0.7)
                radius = max(0, int(radius))
                ind = gt_label[j]
                gen_gaussian_target(center_heatmap_target[batch_id, ind],
                                    [center_x_int, center_y_int], radius)

        # Normalizer for the heatmap loss: number of exact GT peaks.
        avg_factor = max(1, center_heatmap_target.eq(1).sum())
        num_ctrs = [center_2d.shape[0] for center_2d in centers_2d]
        max_objs = max(num_ctrs)

        # Expand the per-image affine mask to one entry per GT box.
        reg_inds = torch.cat(
            [reg_mask[i].repeat(num_ctrs[i]) for i in range(bs)])

        inds = torch.zeros((bs, max_objs),
                           dtype=torch.bool).to(centers_2d[0].device)

        # put gt 3d bboxes to gpu
        gt_bboxes_3d = [
            gt_bbox_3d.to(centers_2d[0].device) for gt_bbox_3d in gt_bboxes_3d
        ]

        # Pad per-image targets to a fixed max_objs so they batch cleanly;
        # `inds` marks which padded slots hold real objects.
        batch_centers_2d = centers_2d[0].new_zeros((bs, max_objs, 2))
        batch_labels_3d = gt_labels_3d[0].new_zeros((bs, max_objs))
        batch_gt_locations = \
            gt_bboxes_3d[0].tensor.new_zeros((bs, max_objs, 3))
        for i in range(bs):
            inds[i, :num_ctrs[i]] = 1
            batch_centers_2d[i, :num_ctrs[i]] = centers_2d[i]
            batch_labels_3d[i, :num_ctrs[i]] = gt_labels_3d[i]
            batch_gt_locations[i, :num_ctrs[i]] = \
                gt_bboxes_3d[i].tensor[:, :3]

        inds = inds.flatten()
        # Project padded centers onto the feature map scale.
        batch_centers_2d = batch_centers_2d.view(-1, 2) * width_ratio
        batch_gt_locations = batch_gt_locations.view(-1, 3)

        # filter the empty image, without gt_bboxes_3d
        gt_bboxes_3d = [
            gt_bbox_3d for gt_bbox_3d in gt_bboxes_3d
            if gt_bbox_3d.tensor.shape[0] > 0
        ]

        # Unpadded concatenations over all real boxes in the batch (N, ...).
        gt_dimensions = torch.cat(
            [gt_bbox_3d.tensor[:, 3:6] for gt_bbox_3d in gt_bboxes_3d])
        gt_orientations = torch.cat([
            gt_bbox_3d.tensor[:, 6].unsqueeze(-1)
            for gt_bbox_3d in gt_bboxes_3d
        ])
        gt_corners = torch.cat(
            [gt_bbox_3d.corners for gt_bbox_3d in gt_bboxes_3d])

        target_labels = dict(
            gt_centers_2d=batch_centers_2d.long(),
            gt_labels_3d=batch_labels_3d,
            indices=inds,
            reg_indices=reg_inds,
            gt_locs=batch_gt_locations,
            gt_dims=gt_dimensions,
            gt_yaws=gt_orientations,
            gt_cors=gt_corners)

        return center_heatmap_target, avg_factor, target_labels

ZCMax's avatar
ZCMax committed
478
479
480
481
482
483
484
485
    def loss_by_feat(
            self,
            cls_scores: List[Tensor],
            bbox_preds: List[Tensor],
            batch_gt_instances_3d: InstanceList,
            batch_gt_instances: InstanceList,
            batch_img_metas: List[dict],
            batch_gt_instances_ignore: OptInstanceList = None) -> dict:
        """Compute loss of the head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level.
                shape (num_gt, 4).
            bbox_preds (list[Tensor]): Box dims is a 4D-tensor, the channel
                number is bbox_code_size.
                shape (B, 7, H, W).
            batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
                gt_instance_3d.  It usually includes ``bboxes_3d``、
                ``labels_3d``、``depths``、``centers_2d`` and attributes.
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance.  It usually includes ``bboxes``、``labels``.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.

        Returns:
            dict[str, Tensor]: A dictionary of loss components, which has
                components below:

            - loss_cls (Tensor): loss of cls heatmap.
            - loss_bbox (Tensor): loss of bbox heatmap.
        """
        # SMOKE predicts on a single scale level only.
        assert len(cls_scores) == len(bbox_preds) == 1
        heatmap_pred = cls_scores[0]
        reg_pred = bbox_preds[0]

        heatmap_target, avg_factor, target_labels = self.get_targets(
            batch_gt_instances_3d, batch_gt_instances, heatmap_pred.shape,
            batch_img_metas)

        # Hybrid boxes: each replaces one component (ori/dim/loc) with the
        # prediction, the rest taken from ground truth.
        pred_bboxes = self.get_predictions(
            labels_3d=target_labels['gt_labels_3d'],
            centers_2d=target_labels['gt_centers_2d'],
            gt_locations=target_labels['gt_locs'],
            gt_dimensions=target_labels['gt_dims'],
            gt_orientations=target_labels['gt_yaws'],
            indices=target_labels['indices'],
            batch_img_metas=batch_img_metas,
            pred_reg=reg_pred)

        loss_cls = self.loss_cls(
            heatmap_pred, heatmap_target, avg_factor=avg_factor)

        # Only boxes from non-affine-augmented images are regressed.
        reg_inds = target_labels['reg_indices']
        target_corners = target_labels['gt_cors'][reg_inds, ...]

        # One corner-based L1 loss per hybrid box; summed in the original
        # dim + loc + ori order.
        loss_bbox = sum(
            self.loss_bbox(pred_bboxes[key].corners[reg_inds, ...],
                           target_corners) for key in ('dim', 'loc', 'ori'))

        return dict(loss_cls=loss_cls, loss_bbox=loss_bbox)