smoke_mono3d_head.py 21.1 KB
Newer Older
1
# Copyright (c) OpenMMLab. All rights reserved.
ZCMax's avatar
ZCMax committed
2
3
from typing import List, Optional, Tuple, Union

4
import torch
ZCMax's avatar
ZCMax committed
5
6
7
from mmcv.runner import force_fp32
from mmengine.config import ConfigDict
from torch import Tensor
8
9
from torch.nn import functional as F

10
from mmdet3d.registry import MODELS
11
12
13
14
15
16
17
18
19
from mmdet.core import multi_apply
from mmdet.core.bbox.builder import build_bbox_coder
from mmdet.models.utils import gaussian_radius, gen_gaussian_target
from mmdet.models.utils.gaussian_target import (get_local_maximum,
                                                get_topk_from_heatmap,
                                                transpose_and_gather_feat)
from .anchor_free_mono3d_head import AnchorFreeMono3DHead


20
@MODELS.register_module()
class SMOKEMono3DHead(AnchorFreeMono3DHead):
    r"""Anchor-free head used in `SMOKE <https://arxiv.org/abs/2002.10111>`_

    .. code-block:: none

                /-----> 3*3 conv -----> 1*1 conv -----> cls
        feature
                \-----> 3*3 conv -----> 1*1 conv -----> reg

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        dim_channel (list[int]): indices of dimension offset preds in
            regression heatmap channels.
        ori_channel (list[int]): indices of orientation offset pred in
            regression heatmap channels.
        bbox_coder (dict): Bbox coder for encoding and decoding boxes.
        loss_cls (dict, optional): Config of classification loss.
            Default: loss_cls=dict(type='GaussianFocalLoss', loss_weight=1.0).
        loss_bbox (dict, optional): Config of localization loss.
            Default: loss_bbox=dict(type='L1Loss', loss_weight=0.1).
        loss_dir (dict, optional): Config of direction classification loss.
            In SMOKE, Default: None.
        loss_attr (dict, optional): Config of attribute classification loss.
            In SMOKE, Default: None.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: norm_cfg=dict(type='GN', num_groups=32, requires_grad=True).
        init_cfg (dict): Initialization config dict. Default: None.
    """  # noqa: E501

    def __init__(self,
                 num_classes: int,
                 in_channels: int,
                 dim_channel: List[int],
                 ori_channel: List[int],
                 bbox_coder: dict,
                 loss_cls: dict = dict(
                     type='GaussianFocalLoss', loss_weight=1.0),
                 loss_bbox: dict = dict(type='L1Loss', loss_weight=0.1),
                 loss_dir: Optional[dict] = None,
                 loss_attr: Optional[dict] = None,
                 norm_cfg: dict = dict(
                     type='GN', num_groups=32, requires_grad=True),
                 init_cfg: Optional[Union[ConfigDict, dict]] = None,
                 **kwargs) -> None:
        super().__init__(
            num_classes,
            in_channels,
            loss_cls=loss_cls,
            loss_bbox=loss_bbox,
            loss_dir=loss_dir,
            loss_attr=loss_attr,
            norm_cfg=norm_cfg,
            init_cfg=init_cfg,
            **kwargs)
        # Channel indices of the regression map that carry the dimension
        # offsets and the orientation (sin/cos) vector respectively; they
        # receive dedicated activations in ``forward_single``.
        self.dim_channel = dim_channel
        self.ori_channel = ori_channel
        self.bbox_coder = build_bbox_coder(bbox_coder)

ZCMax's avatar
ZCMax committed
82
    def forward(self, feats: Tuple[Tensor]):
        """Run the head on each scale level of the upstream features.

        Args:
            feats (tuple[Tensor]): Multi-level features from the upstream
                network; every element is a 4D tensor.

        Returns:
            tuple:
                cls_scores (list[Tensor]): Box scores for each scale level,
                    each is a 4D-tensor, the channel number is
                    num_points * num_classes.
                bbox_preds (list[Tensor]): Box energies / deltas for each scale
                    level, each is a 4D-tensor, the channel number is
                    num_points * bbox_code_size.
        """
        # Apply the per-level forward to every feature map and regroup the
        # per-level outputs into per-output lists.
        return multi_apply(self.forward_single, feats)

ZCMax's avatar
ZCMax committed
100
    def forward_single(self, x: Tensor) -> Union[Tensor, Tensor]:
        """Run the head on a single scale level.

        Args:
            x (Tensor): Input feature map.

        Returns:
            tuple: Scores for each class, bbox of input feature maps.
        """
        # Only the class scores and the regression map of the parent head's
        # outputs are used by SMOKE.
        cls_score, bbox_pred = super().forward_single(x)[:2]
        # Map logits to probabilities, then keep them strictly inside (0, 1)
        # for numerical stability.
        cls_score = cls_score.sigmoid().clamp(min=1e-4, max=1 - 1e-4)
        # Constrain the dimension-offset channels to (-0.5, 0.5), in place.
        # (N, C, H, W)
        dim_slice = bbox_pred[:, self.dim_channel, ...]
        bbox_pred[:, self.dim_channel, ...] = dim_slice.sigmoid() - 0.5
        # Normalize the orientation channels to a unit vector, in place.
        # (N, C, H, W)
        ori_slice = bbox_pred[:, self.ori_channel, ...]
        bbox_pred[:, self.ori_channel, ...] = F.normalize(ori_slice)
        return cls_score, bbox_pred

ZCMax's avatar
ZCMax committed
121
122
123
124
125
126
    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
    def get_results(self,
                    cls_scores,
                    bbox_preds,
                    batch_img_metas,
                    rescale=None):
        """Turn head predictions into per-image 3D detection results.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level.
            bbox_preds (list[Tensor]): Box regression for each scale.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            rescale (bool): If True, return boxes in original image space.

        Returns:
            list[tuple[:obj:`CameraInstance3DBoxes`, Tensor, Tensor, None]]:
                Each item in result_list is 4-tuple.
        """
        # SMOKE is single-level: exactly one score map and one regression map.
        assert len(cls_scores) == len(bbox_preds) == 1
        score_map = cls_scores[0]
        reg_map = bbox_preds[0]
        # Per-image camera intrinsics and image->feature-map transforms.
        cam2imgs = torch.stack([
            score_map.new_tensor(meta['cam2img']) for meta in batch_img_metas
        ])
        trans_mats = torch.stack([
            score_map.new_tensor(meta['trans_mat']) for meta in batch_img_metas
        ])
        batch_bboxes, batch_scores, batch_topk_labels = self.decode_heatmap(
            score_map,
            reg_map,
            batch_img_metas,
            cam2imgs=cam2imgs,
            trans_mats=trans_mats,
            topk=100,
            kernel=3)

        result_list = []
        for img_id, img_meta in enumerate(batch_img_metas):
            scores = batch_scores[img_id]
            # Fixed confidence threshold used by SMOKE.
            keep = scores > 0.25
            bboxes = batch_bboxes[img_id][keep]
            labels = batch_topk_labels[img_id][keep]
            scores = scores[keep]

            bboxes = img_meta['box_type_3d'](
                bboxes, box_dim=self.bbox_code_size, origin=(0.5, 0.5, 0.5))
            # SMOKE predicts no attributes.
            result_list.append((bboxes, scores, labels, None))

        return result_list

    def decode_heatmap(self,
                       cls_score,
                       reg_pred,
                       batch_img_metas,
                       cam2imgs,
                       trans_mats,
                       topk=100,
                       kernel=3):
        """Decode the center heatmap and regression map into raw 3D boxes.

        Args:
            cls_score (Tensor): Center predict heatmap,
                shape (B, num_classes, H, W).
            reg_pred (Tensor): Box regression map.
                shape (B, channel, H , W).
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            cam2imgs (Tensor): Camera intrinsic matrixs.
                shape (B, 4, 4)
            trans_mats (Tensor): Transformation matrix from original image
                to feature map.
                shape: (batch, 3, 3)
            topk (int): Get top k center keypoints from heatmap. Default 100.
            kernel (int): Max pooling kernel for extract local maximum pixels.
                Default 3.

        Returns:
            tuple[torch.Tensor]: Decoded output of SMOKEHead, containing
               the following Tensors:
              - batch_bboxes (Tensor): Coords of each 3D box.
                    shape (B, k, 7)
              - batch_scores (Tensor): Scores of each 3D box.
                    shape (B, k)
              - batch_topk_labels (Tensor): Categories of each 3D box.
                    shape (B, k)
        """
        img_h, img_w = batch_img_metas[0]['pad_shape'][:2]
        bs, _, feat_h, feat_w = cls_score.shape

        # Suppress non-peak pixels with a max-pool NMS, then pick the top-k
        # peaks over all classes.
        peak_map = get_local_maximum(cls_score, kernel=kernel)
        batch_scores, batch_index, batch_topk_labels, topk_ys, topk_xs = \
            get_topk_from_heatmap(peak_map, k=topk)

        # Gather the 8-channel regression vector at each kept peak.
        reg_pois = transpose_and_gather_feat(reg_pred, batch_index)
        reg_pois = reg_pois.view(-1, 8)

        # Peak coordinates on the feature map, one (x, y) row per keypoint.
        points = torch.cat([topk_xs.view(-1, 1),
                            topk_ys.view(-1, 1).float()],
                           dim=1)
        locations, dimensions, orientations = self.bbox_coder.decode(
            reg_pois, points, batch_topk_labels, cam2imgs, trans_mats)

        batch_bboxes = torch.cat((locations, dimensions, orientations), dim=1)
        batch_bboxes = batch_bboxes.view(bs, -1, self.bbox_code_size)
        return batch_bboxes, batch_scores, batch_topk_labels

ZCMax's avatar
ZCMax committed
235
236
237
    def get_predictions(self, labels_3d, centers_2d, gt_locations,
                        gt_dimensions, gt_orientations, indices,
                        batch_img_metas, pred_reg):
        """Prepare predictions for computing loss.

        Gathers the regression vectors at the gt projected centers, decodes
        them, and encodes three hybrid boxes where exactly one component
        (orientation, dimension or location) comes from the prediction and
        the rest from ground truth, as in the SMOKE disentangled loss.

        Args:
            labels_3d (Tensor): Labels of each 3D box.
                shape (B, max_objs, )
            centers_2d (Tensor): Coords of each projected 3D box
                center on image. shape (B * max_objs, 2)
            gt_locations (Tensor): Coords of each 3D box's location.
                shape (B * max_objs, 3)
            gt_dimensions (Tensor): Dimensions of each 3D box.
                shape (N, 3)
            gt_orientations (Tensor): Orientation(yaw) of each 3D box.
                shape (N, 1)
            indices (Tensor): Indices of the existence of the 3D box.
                shape (B * max_objs, )
            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            pred_reg (Tensor): Box regression map.
                shape (B, channel, H , W).

        Returns:
            dict: the dict has components below:
            - bbox3d_yaws (:obj:`CameraInstance3DBoxes`):
                bbox calculated using pred orientations.
            - bbox3d_dims (:obj:`CameraInstance3DBoxes`):
                bbox calculated using pred dimensions.
            - bbox3d_locs (:obj:`CameraInstance3DBoxes`):
                bbox calculated using pred locations.
        """
        batch, channel = pred_reg.shape[0], pred_reg.shape[1]
        w = pred_reg.shape[3]
        # Per-image camera intrinsics and image->feature-map transforms.
        cam2imgs = torch.stack([
            gt_locations.new_tensor(img_metas['cam2img'])
            for img_metas in batch_img_metas
        ])
        trans_mats = torch.stack([
            gt_locations.new_tensor(img_metas['trans_mat'])
            for img_metas in batch_img_metas
        ])
        # Flatten (x, y) center coords to linear indices into the H*W plane.
        centers_2d_inds = centers_2d[:, 1] * w + centers_2d[:, 0]
        centers_2d_inds = centers_2d_inds.view(batch, -1)
        pred_regression = transpose_and_gather_feat(pred_reg, centers_2d_inds)
        pred_regression_pois = pred_regression.view(-1, channel)
        locations, dimensions, orientations = self.bbox_coder.decode(
            pred_regression_pois, centers_2d, labels_3d, cam2imgs, trans_mats,
            gt_locations)

        # Keep only the slots that actually hold an object.
        locations, dimensions, orientations = locations[indices], dimensions[
            indices], orientations[indices]

        # Shift the predicted center down by half the box height (camera y
        # axis points down) — presumably to match the bottom-center origin
        # used by the gt boxes; TODO confirm against the bbox_coder.
        locations[:, 1] += dimensions[:, 1] / 2

        gt_locations = gt_locations[indices]

        assert len(locations) == len(gt_locations)
        assert len(dimensions) == len(gt_dimensions)
        assert len(orientations) == len(gt_orientations)
        # Disentangled encoding: each box takes one predicted component and
        # ground truth for the other two.
        bbox3d_yaws = self.bbox_coder.encode(gt_locations, gt_dimensions,
                                             orientations, batch_img_metas)
        bbox3d_dims = self.bbox_coder.encode(gt_locations, dimensions,
                                             gt_orientations, batch_img_metas)
        bbox3d_locs = self.bbox_coder.encode(locations, gt_dimensions,
                                             gt_orientations, batch_img_metas)

        pred_bboxes = dict(ori=bbox3d_yaws, dim=bbox3d_dims, loc=bbox3d_locs)

        return pred_bboxes

ZCMax's avatar
ZCMax committed
306
    def get_targets(self, batch_gt_instances_3d, feat_shape, batch_img_metas):
        """Get training targets for batch images.

        Builds the Gaussian center heatmap target and packs the per-object
        ground-truth quantities (padded to ``max_objs`` per image) that the
        regression loss consumes.

        Args:
            batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
                gt_instance_3d.  It usually includes ``bboxes``、``labels``
                、``bboxes_3d``、``labels_3d``、``depths``、``centers_2d`` and
                attributes.
            feat_shape (tuple[int]): Feature map shape with value,
                shape (B, _, H, W).
            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.

        Returns:
            tuple[Tensor, int, dict]: The Tensor value is the targets of
                center heatmap, the int is the heatmap normalizer, and the
                dict has components below:
              - gt_centers_2d (Tensor): Coords of each projected 3D box
                    center on image. shape (B * max_objs, 2)
              - gt_labels_3d (Tensor): Labels of each 3D box.
                    shape (B, max_objs, )
              - indices (Tensor): Indices of the existence of the 3D box.
                    shape (B * max_objs, )
              - reg_indices (Tensor): Mask of objects whose image was NOT
                    affine-augmented (those are excluded from regression).
                    shape (N, )
              - gt_locs (Tensor): Coords of each 3D box's location.
                    shape (N, 3)
              - gt_dims (Tensor): Dimensions of each 3D box.
                    shape (N, 3)
              - gt_yaws (Tensor): Orientation(yaw) of each 3D box.
                    shape (N, 1)
              - gt_cors (Tensor): Coords of the corners of each 3D box.
                    shape (N, 8, 3)
        """

        # Unpack the per-image ground-truth fields.
        gt_bboxes = [
            gt_instances_3d.bboxes for gt_instances_3d in batch_gt_instances_3d
        ]
        gt_labels = [
            gt_instances_3d.labels for gt_instances_3d in batch_gt_instances_3d
        ]
        gt_bboxes_3d = [
            gt_instances_3d.bboxes_3d
            for gt_instances_3d in batch_gt_instances_3d
        ]
        gt_labels_3d = [
            gt_instances_3d.labels_3d
            for gt_instances_3d in batch_gt_instances_3d
        ]
        centers_2d = [
            gt_instances_3d.centers_2d
            for gt_instances_3d in batch_gt_instances_3d
        ]
        img_shape = batch_img_metas[0]['pad_shape']

        # True for images WITHOUT affine augmentation; only those contribute
        # to the regression loss.
        reg_mask = torch.stack([
            gt_bboxes[0].new_tensor(
                not img_metas['affine_aug'], dtype=torch.bool)
            for img_metas in batch_img_metas
        ])

        img_h, img_w = img_shape[:2]
        bs, _, feat_h, feat_w = feat_shape

        width_ratio = float(feat_w / img_w)  # 1/4
        height_ratio = float(feat_h / img_h)  # 1/4

        # The heatmap projection below assumes an isotropic downsample.
        assert width_ratio == height_ratio

        center_heatmap_target = gt_bboxes[-1].new_zeros(
            [bs, self.num_classes, feat_h, feat_w])

        # Shallow copy: only the list is copied, the tensors are shared.
        gt_centers_2d = centers_2d.copy()

        for batch_id in range(bs):
            gt_bbox = gt_bboxes[batch_id]
            gt_label = gt_labels[batch_id]
            # project centers_2d from input image to feat map
            gt_center_2d = gt_centers_2d[batch_id] * width_ratio

            for j, center in enumerate(gt_center_2d):
                center_x_int, center_y_int = center.int()
                scale_box_h = (gt_bbox[j][3] - gt_bbox[j][1]) * height_ratio
                scale_box_w = (gt_bbox[j][2] - gt_bbox[j][0]) * width_ratio
                # Gaussian radius sized so boxes at min_overlap IoU still
                # cover the peak.
                radius = gaussian_radius([scale_box_h, scale_box_w],
                                         min_overlap=0.7)
                radius = max(0, int(radius))
                ind = gt_label[j]
                gen_gaussian_target(center_heatmap_target[batch_id, ind],
                                    [center_x_int, center_y_int], radius)

        # Normalizer for the focal loss: number of exact peaks, at least 1.
        avg_factor = max(1, center_heatmap_target.eq(1).sum())
        num_ctrs = [center_2d.shape[0] for center_2d in centers_2d]
        max_objs = max(num_ctrs)

        # Expand the per-image mask to one entry per gt object.
        reg_inds = torch.cat(
            [reg_mask[i].repeat(num_ctrs[i]) for i in range(bs)])

        # Validity mask over the padded (bs, max_objs) slots.
        inds = torch.zeros((bs, max_objs),
                           dtype=torch.bool).to(centers_2d[0].device)

        # put gt 3d bboxes to gpu
        gt_bboxes_3d = [
            gt_bbox_3d.to(centers_2d[0].device) for gt_bbox_3d in gt_bboxes_3d
        ]

        # Pad per-image tensors to max_objs so they can be batched.
        batch_centers_2d = centers_2d[0].new_zeros((bs, max_objs, 2))
        batch_labels_3d = gt_labels_3d[0].new_zeros((bs, max_objs))
        batch_gt_locations = \
            gt_bboxes_3d[0].tensor.new_zeros((bs, max_objs, 3))
        for i in range(bs):
            inds[i, :num_ctrs[i]] = 1
            batch_centers_2d[i, :num_ctrs[i]] = centers_2d[i]
            batch_labels_3d[i, :num_ctrs[i]] = gt_labels_3d[i]
            batch_gt_locations[i, :num_ctrs[i]] = \
                gt_bboxes_3d[i].tensor[:, :3]

        inds = inds.flatten()
        # Project padded centers to feature-map scale.
        batch_centers_2d = batch_centers_2d.view(-1, 2) * width_ratio
        batch_gt_locations = batch_gt_locations.view(-1, 3)

        # filter the empty image, without gt_bboxes_3d
        gt_bboxes_3d = [
            gt_bbox_3d for gt_bbox_3d in gt_bboxes_3d
            if gt_bbox_3d.tensor.shape[0] > 0
        ]

        # Box tensor layout: cols 3:6 are (dims), col 6 is yaw.
        gt_dimensions = torch.cat(
            [gt_bbox_3d.tensor[:, 3:6] for gt_bbox_3d in gt_bboxes_3d])
        gt_orientations = torch.cat([
            gt_bbox_3d.tensor[:, 6].unsqueeze(-1)
            for gt_bbox_3d in gt_bboxes_3d
        ])
        gt_corners = torch.cat(
            [gt_bbox_3d.corners for gt_bbox_3d in gt_bboxes_3d])

        target_labels = dict(
            gt_centers_2d=batch_centers_2d.long(),
            gt_labels_3d=batch_labels_3d,
            indices=inds,
            reg_indices=reg_inds,
            gt_locs=batch_gt_locations,
            gt_dims=gt_dimensions,
            gt_yaws=gt_orientations,
            gt_cors=gt_corners)

        return center_heatmap_target, avg_factor, target_labels

    def loss(self,
             cls_scores,
             bbox_preds,
             batch_gt_instances_3d,
             batch_img_metas,
             batch_gt_instances_ignore=None):
        """Compute loss of the head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level.
                shape (num_gt, 4).
            bbox_preds (list[Tensor]): Box dims is a 4D-tensor, the channel
                number is bbox_code_size.
                shape (B, 7, H, W).
            batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
                gt_instance_3d.  It usually includes ``bboxes``、``labels``
                、``bboxes_3d``、``labels_3d``、``depths``、``centers_2d`` and
                attributes.
            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        # SMOKE is single-level and does not support ignored instances.
        assert len(cls_scores) == len(bbox_preds) == 1
        assert batch_gt_instances_ignore is None
        heatmap_pred = cls_scores[0]
        reg_pred = bbox_preds[0]

        heatmap_target, avg_factor, target_labels = self.get_targets(
            batch_gt_instances_3d, heatmap_pred.shape, batch_img_metas)

        # Decode predictions at gt centers into three disentangled box sets.
        pred_bboxes = self.get_predictions(
            labels_3d=target_labels['gt_labels_3d'],
            centers_2d=target_labels['gt_centers_2d'],
            gt_locations=target_labels['gt_locs'],
            gt_dimensions=target_labels['gt_dims'],
            gt_orientations=target_labels['gt_yaws'],
            indices=target_labels['indices'],
            batch_img_metas=batch_img_metas,
            pred_reg=reg_pred)

        loss_cls = self.loss_cls(
            heatmap_pred, heatmap_target, avg_factor=avg_factor)

        reg_inds = target_labels['reg_indices']
        gt_corners = target_labels['gt_cors'][reg_inds, ...]

        # One corner loss per disentangled component (orientation,
        # dimension, location), each against the same gt corners.
        loss_bbox_oris = self.loss_bbox(
            pred_bboxes['ori'].corners[reg_inds, ...], gt_corners)
        loss_bbox_dims = self.loss_bbox(
            pred_bboxes['dim'].corners[reg_inds, ...], gt_corners)
        loss_bbox_locs = self.loss_bbox(
            pred_bboxes['loc'].corners[reg_inds, ...], gt_corners)

        loss_bbox = loss_bbox_dims + loss_bbox_locs + loss_bbox_oris

        return dict(loss_cls=loss_cls, loss_bbox=loss_bbox)