# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union

import torch
from mmcv.runner import force_fp32
from mmengine.config import ConfigDict
from mmengine.data import InstanceData
from torch import Tensor
from torch.nn import functional as F

from mmdet3d.registry import MODELS, TASK_UTILS
from mmdet.core import multi_apply
from mmdet.models.utils import gaussian_radius, gen_gaussian_target
from mmdet.models.utils.gaussian_target import (get_local_maximum,
                                                get_topk_from_heatmap,
                                                transpose_and_gather_feat)
from .anchor_free_mono3d_head import AnchorFreeMono3DHead


@MODELS.register_module()
class SMOKEMono3DHead(AnchorFreeMono3DHead):
    r"""Anchor-free head used in `SMOKE <https://arxiv.org/abs/2002.10111>`_

    .. code-block:: none

                /-----> 3*3 conv -----> 1*1 conv -----> cls
        feature
                \-----> 3*3 conv -----> 1*1 conv -----> reg

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        dim_channel (list[int]): Indices of dimension offset preds in
            regression heatmap channels.
        ori_channel (list[int]): Indices of orientation offset preds in
            regression heatmap channels.
        bbox_coder (dict): Bbox coder for encoding and decoding boxes.
        loss_cls (dict, optional): Config of classification loss.
            Default: loss_cls=dict(type='GaussianFocalLoss', loss_weight=1.0).
        loss_bbox (dict, optional): Config of localization loss.
            Default: loss_bbox=dict(type='L1Loss', loss_weight=0.1).
        loss_dir (dict, optional): Config of direction classification loss.
            Defaults to None in SMOKE.
        loss_attr (dict, optional): Config of attribute classification loss.
            Defaults to None in SMOKE.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: norm_cfg=dict(type='GN', num_groups=32, requires_grad=True).
        init_cfg (dict): Initialization config dict. Default: None.
    """  # noqa: E501

    def __init__(self,
                 num_classes: int,
                 in_channels: int,
                 dim_channel: List[int],
                 ori_channel: List[int],
                 bbox_coder: dict,
                 loss_cls: dict = dict(
                     type='GaussianFocalLoss', loss_weight=1.0),
                 loss_bbox: dict = dict(type='L1Loss', loss_weight=0.1),
                 loss_dir: Optional[dict] = None,
                 loss_attr: Optional[dict] = None,
                 norm_cfg: dict = dict(
                     type='GN', num_groups=32, requires_grad=True),
                 init_cfg: Optional[Union[ConfigDict, dict]] = None,
                 **kwargs) -> None:
        super().__init__(
            num_classes,
            in_channels,
            loss_cls=loss_cls,
            loss_bbox=loss_bbox,
            loss_dir=loss_dir,
            loss_attr=loss_attr,
            norm_cfg=norm_cfg,
            init_cfg=init_cfg,
            **kwargs)
        self.dim_channel = dim_channel
        self.ori_channel = ori_channel
        self.bbox_coder = TASK_UTILS.build(bbox_coder)

    def forward(self, feats: Tuple[Tensor]):
        """Forward features from the upstream network.

        Args:
            feats (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple:
                cls_scores (list[Tensor]): Box scores for each scale level,
                    each is a 4D-tensor, the channel number is
                    num_points * num_classes.
                bbox_preds (list[Tensor]): Box energies / deltas for each scale
                    level, each is a 4D-tensor, the channel number is
                    num_points * bbox_code_size.
        """
        return multi_apply(self.forward_single, feats)
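
    # Illustrative sketch, not part of the original implementation:
    # ``multi_apply`` maps ``forward_single`` over every feature level and
    # regroups the per-level outputs, roughly equivalent to:
    #
    #   outs = [self.forward_single(feat) for feat in feats]
    #   cls_scores = [out[0] for out in outs]
    #   bbox_preds = [out[1] for out in outs]
    #   return cls_scores, bbox_preds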

    def forward_single(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """Forward features of a single scale level.

        Args:
            x (Tensor): Input feature map.

        Returns:
            tuple: Scores for each class, bbox of input feature maps.
        """
        cls_score, bbox_pred, dir_cls_pred, attr_pred, cls_feat, reg_feat = \
            super().forward_single(x)
        cls_score = cls_score.sigmoid()  # turn to 0-1
        cls_score = cls_score.clamp(min=1e-4, max=1 - 1e-4)
        # (N, C, H, W)
        offset_dims = bbox_pred[:, self.dim_channel, ...]
        bbox_pred[:, self.dim_channel, ...] = offset_dims.sigmoid() - 0.5
        # (N, C, H, W)
        vector_ori = bbox_pred[:, self.ori_channel, ...]
        bbox_pred[:, self.ori_channel, ...] = F.normalize(vector_ori)
        return cls_score, bbox_pred
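
    # Illustrative sketch, not part of the original implementation: the
    # per-channel post-processing above on a dummy regression map. The shape
    # and channel indices are assumptions. Dimension-offset channels are
    # squashed to (-0.5, 0.5) and orientation channels are L2-normalized into
    # a unit (sin, cos) vector.
    #
    #   import torch
    #   from torch.nn import functional as F
    #   bbox_pred = torch.randn(2, 8, 96, 320)  # (N, C, H, W), C assumed 8
    #   dim_channel, ori_channel = [3, 4, 5], [6, 7]
    #   bbox_pred[:, dim_channel] = bbox_pred[:, dim_channel].sigmoid() - 0.5
    #   bbox_pred[:, ori_channel] = F.normalize(
    #       bbox_pred[:, ori_channel], dim=1)
    #   # each (sin, cos) pair now has unit norm
    #   assert torch.allclose(bbox_pred[:, ori_channel].norm(dim=1),
    #                         torch.ones(2, 96, 320), atol=1e-5)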

    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
    def get_results(self,
                    cls_scores,
                    bbox_preds,
                    batch_img_metas,
                    rescale=None):
        """Generate bboxes from bbox head predictions.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level.
            bbox_preds (list[Tensor]): Box regression for each scale.
            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            rescale (bool): If True, return boxes in original image space.

        Returns:
            list[:obj:`InstanceData`]: 3D detection results of each image,
                each item usually contains ``bboxes_3d``, ``scores_3d`` and
                ``labels_3d``.
        """
        assert len(cls_scores) == len(bbox_preds) == 1
        cam2imgs = torch.stack([
            cls_scores[0].new_tensor(img_meta['cam2img'])
            for img_meta in batch_img_metas
        ])
        trans_mats = torch.stack([
            cls_scores[0].new_tensor(img_meta['trans_mat'])
            for img_meta in batch_img_metas
        ])
        batch_bboxes, batch_scores, batch_topk_labels = self.decode_heatmap(
            cls_scores[0],
            bbox_preds[0],
            batch_img_metas,
            cam2imgs=cam2imgs,
            trans_mats=trans_mats,
            topk=100,
            kernel=3)

        result_list = []
        for img_id in range(len(batch_img_metas)):

            bboxes = batch_bboxes[img_id]
            scores = batch_scores[img_id]
            labels = batch_topk_labels[img_id]

            keep_idx = scores > 0.25
            bboxes = bboxes[keep_idx]
            scores = scores[keep_idx]
            labels = labels[keep_idx]

            bboxes = batch_img_metas[img_id]['box_type_3d'](
                bboxes, box_dim=self.bbox_code_size, origin=(0.5, 0.5, 0.5))
            attrs = None

            results = InstanceData()
            results.bboxes_3d = bboxes
            results.labels_3d = labels
            results.scores_3d = scores

            if attrs is not None:
                results.attr_labels = attrs

            result_list.append(results)

        return result_list
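
    # Illustrative sketch, not part of the original implementation: a
    # stripped-down version of the per-image packing above. The helper name
    # ``_pack_instances`` is hypothetical.
    #
    #   from mmengine.data import InstanceData
    #   def _pack_instances(bboxes, scores, labels, box_type_3d, box_dim=7):
    #       keep = scores > 0.25  # same score threshold as used above
    #       results = InstanceData()
    #       results.bboxes_3d = box_type_3d(
    #           bboxes[keep], box_dim=box_dim, origin=(0.5, 0.5, 0.5))
    #       results.scores_3d = scores[keep]
    #       results.labels_3d = labels[keep]
    #       return results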

    def decode_heatmap(self,
                       cls_score,
                       reg_pred,
                       batch_img_metas,
                       cam2imgs,
                       trans_mats,
                       topk=100,
                       kernel=3):
        """Transform outputs into detections raw bbox predictions.

        Args:
            class_score (Tensor): Center predict heatmap,
                shape (B, num_classes, H, W).
            reg_pred (Tensor): Box regression map.
                shape (B, channel, H , W).
            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            cam2imgs (Tensor): Camera intrinsic matrices,
                shape (B, 4, 4).
            trans_mats (Tensor): Transformation matrices from the original
                image to the feature map, shape (B, 3, 3).
            topk (int): Get top k center keypoints from heatmap. Default 100.
            kernel (int): Max pooling kernel for extracting local maximum
                pixels. Default 3.

        Returns:
            tuple[torch.Tensor]: Decoded output of SMOKEHead, containing
                the following Tensors:
              - batch_bboxes (Tensor): Coords of each 3D box,
                    shape (B, k, 7).
              - batch_scores (Tensor): Scores of each 3D box,
                    shape (B, k).
              - batch_topk_labels (Tensor): Categories of each 3D box,
                    shape (B, k).
        """
        img_h, img_w = batch_img_metas[0]['pad_shape'][:2]
        bs, _, feat_h, feat_w = cls_score.shape

        center_heatmap_pred = get_local_maximum(cls_score, kernel=kernel)

        *batch_dets, topk_ys, topk_xs = get_topk_from_heatmap(
            center_heatmap_pred, k=topk)
        batch_scores, batch_index, batch_topk_labels = batch_dets

        regression = transpose_and_gather_feat(reg_pred, batch_index)
        regression = regression.view(-1, 8)

        points = torch.cat([topk_xs.view(-1, 1),
                            topk_ys.view(-1, 1).float()],
                           dim=1)
        locations, dimensions, orientations = self.bbox_coder.decode(
            regression, points, batch_topk_labels, cam2imgs, trans_mats)

        batch_bboxes = torch.cat((locations, dimensions, orientations), dim=1)
        batch_bboxes = batch_bboxes.view(bs, -1, self.bbox_code_size)
        return batch_bboxes, batch_scores, batch_topk_labels
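
    # Illustrative sketch, not part of the original implementation:
    # ``get_local_maximum`` performs max-pooling based NMS on the heatmap,
    # keeping only pixels that equal their local maximum. A simplified
    # re-implementation with plain torch ops (odd ``kernel`` assumed):
    #
    #   from torch.nn import functional as F
    #   def _local_maximum(heatmap, kernel=3):
    #       pad = (kernel - 1) // 2
    #       hmax = F.max_pool2d(heatmap, kernel, stride=1, padding=pad)
    #       return heatmap * (hmax == heatmap).float()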

    def get_predictions(self, labels_3d, centers_2d, gt_locations,
                        gt_dimensions, gt_orientations, indices,
                        batch_img_metas, pred_reg):
        """Prepare predictions for computing loss.

        Args:
            labels_3d (Tensor): Labels of each 3D box.
                shape (B, max_objs, )
            centers_2d (Tensor): Coords of each projected 3D box
                center on image. shape (B * max_objs, 2)
            gt_locations (Tensor): Coords of each 3D box's location.
                shape (B * max_objs, 3)
            gt_dimensions (Tensor): Dimensions of each 3D box.
                shape (N, 3)
            gt_orientations (Tensor): Orientation(yaw) of each 3D box.
                shape (N, 1)
            indices (Tensor): Mask indicating the existence of each 3D box.
                shape (B * max_objs, )
            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            pred_reg (Tensor): Box regression map,
                shape (B, channel, H, W).

        Returns:
            dict: The dict has components below:
            - bbox3d_yaws (:obj:`CameraInstance3DBoxes`):
                bbox calculated using pred orientations.
            - bbox3d_dims (:obj:`CameraInstance3DBoxes`):
                bbox calculated using pred dimensions.
            - bbox3d_locs (:obj:`CameraInstance3DBoxes`):
                bbox calculated using pred locations.
        """
        batch, channel = pred_reg.shape[0], pred_reg.shape[1]
        w = pred_reg.shape[3]
        cam2imgs = torch.stack([
            gt_locations.new_tensor(img_meta['cam2img'])
            for img_meta in batch_img_metas
        ])
        trans_mats = torch.stack([
            gt_locations.new_tensor(img_meta['trans_mat'])
            for img_meta in batch_img_metas
        ])
        centers_2d_inds = centers_2d[:, 1] * w + centers_2d[:, 0]
        centers_2d_inds = centers_2d_inds.view(batch, -1)
        pred_regression = transpose_and_gather_feat(pred_reg, centers_2d_inds)
        pred_regression_pois = pred_regression.view(-1, channel)
        locations, dimensions, orientations = self.bbox_coder.decode(
            pred_regression_pois, centers_2d, labels_3d, cam2imgs, trans_mats,
            gt_locations)

        locations, dimensions, orientations = locations[indices], dimensions[
            indices], orientations[indices]

        locations[:, 1] += dimensions[:, 1] / 2

        gt_locations = gt_locations[indices]

        assert len(locations) == len(gt_locations)
        assert len(dimensions) == len(gt_dimensions)
        assert len(orientations) == len(gt_orientations)
        bbox3d_yaws = self.bbox_coder.encode(gt_locations, gt_dimensions,
                                             orientations, batch_img_metas)
        bbox3d_dims = self.bbox_coder.encode(gt_locations, dimensions,
                                             gt_orientations, batch_img_metas)
        bbox3d_locs = self.bbox_coder.encode(locations, gt_dimensions,
                                             gt_orientations, batch_img_metas)

        pred_bboxes = dict(ori=bbox3d_yaws, dim=bbox3d_dims, loc=bbox3d_locs)

        return pred_bboxes
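
    # Illustrative sketch, not part of the original implementation: the gather
    # step above flattens the spatial dims and indexes them with
    # ``y * w + x``. A stand-alone equivalent (``_gather_at_centers`` is a
    # hypothetical helper):
    #
    #   def _gather_at_centers(pred_reg, centers_2d_inds):
    #       # pred_reg: (B, C, H, W); centers_2d_inds: (B, max_objs)
    #       b, c = pred_reg.shape[:2]
    #       flat = pred_reg.view(b, c, -1).permute(0, 2, 1)  # (B, H*W, C)
    #       idx = centers_2d_inds.unsqueeze(-1).expand(-1, -1, c)
    #       return flat.gather(1, idx)  # (B, max_objs, C)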

    def get_targets(self, batch_gt_instances_3d, feat_shape, batch_img_metas):
        """Get training targets for batch images.

        Args:
            batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
                gt_instance_3d. It usually includes ``bboxes``, ``labels``,
                ``bboxes_3d``, ``labels_3d``, ``depths``, ``centers_2d`` and
                attributes.
            feat_shape (tuple[int]): Shape of the feature map,
                i.e. (B, _, H, W).
            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.

        Returns:
            tuple: The center heatmap target, the ``avg_factor`` used to
                normalize the classification loss, and a dict with
                components below:
              - gt_centers_2d (Tensor): Coords of each projected 3D box
                    center on image. shape (B * max_objs, 2)
              - gt_labels_3d (Tensor): Labels of each 3D box.
                    shape (B, max_objs, )
              - indices (Tensor): Mask indicating the existence of each 3D
                    box. shape (B * max_objs, )
              - reg_indices (Tensor): Mask of 3D boxes coming from images
                    without affine augmentation.
                    shape (N, )
              - gt_locs (Tensor): Coords of each 3D box's location.
                    shape (N, 3)
              - gt_dims (Tensor): Dimensions of each 3D box.
                    shape (N, 3)
              - gt_yaws (Tensor): Orientation(yaw) of each 3D box.
                    shape (N, 1)
              - gt_cors (Tensor): Coords of the corners of each 3D box.
                    shape (N, 8, 3)
        """

        gt_bboxes = [
            gt_instances_3d.bboxes for gt_instances_3d in batch_gt_instances_3d
        ]
        gt_labels = [
            gt_instances_3d.labels for gt_instances_3d in batch_gt_instances_3d
        ]
        gt_bboxes_3d = [
            gt_instances_3d.bboxes_3d
            for gt_instances_3d in batch_gt_instances_3d
        ]
        gt_labels_3d = [
            gt_instances_3d.labels_3d
            for gt_instances_3d in batch_gt_instances_3d
        ]
        centers_2d = [
            gt_instances_3d.centers_2d
            for gt_instances_3d in batch_gt_instances_3d
        ]
        img_shape = batch_img_metas[0]['pad_shape']

        reg_mask = torch.stack([
            gt_bboxes[0].new_tensor(
                not img_meta['affine_aug'], dtype=torch.bool)
            for img_meta in batch_img_metas
        ])

        img_h, img_w = img_shape[:2]
        bs, _, feat_h, feat_w = feat_shape

        width_ratio = float(feat_w / img_w)  # 1/4
        height_ratio = float(feat_h / img_h)  # 1/4

        assert width_ratio == height_ratio

        center_heatmap_target = gt_bboxes[-1].new_zeros(
            [bs, self.num_classes, feat_h, feat_w])

        gt_centers_2d = centers_2d.copy()

        for batch_id in range(bs):
            gt_bbox = gt_bboxes[batch_id]
            gt_label = gt_labels[batch_id]
            # project centers_2d from input image to feat map
            gt_center_2d = gt_centers_2d[batch_id] * width_ratio

            for j, center in enumerate(gt_center_2d):
                center_x_int, center_y_int = center.int()
                scale_box_h = (gt_bbox[j][3] - gt_bbox[j][1]) * height_ratio
                scale_box_w = (gt_bbox[j][2] - gt_bbox[j][0]) * width_ratio
                radius = gaussian_radius([scale_box_h, scale_box_w],
                                         min_overlap=0.7)
                radius = max(0, int(radius))
                ind = gt_label[j]
                gen_gaussian_target(center_heatmap_target[batch_id, ind],
                                    [center_x_int, center_y_int], radius)

        avg_factor = max(1, center_heatmap_target.eq(1).sum())
        num_ctrs = [center_2d.shape[0] for center_2d in centers_2d]
        max_objs = max(num_ctrs)

        reg_inds = torch.cat(
            [reg_mask[i].repeat(num_ctrs[i]) for i in range(bs)])

        inds = torch.zeros((bs, max_objs),
                           dtype=torch.bool).to(centers_2d[0].device)

        # put gt 3d bboxes to gpu
        gt_bboxes_3d = [
            gt_bbox_3d.to(centers_2d[0].device) for gt_bbox_3d in gt_bboxes_3d
        ]

        batch_centers_2d = centers_2d[0].new_zeros((bs, max_objs, 2))
        batch_labels_3d = gt_labels_3d[0].new_zeros((bs, max_objs))
        batch_gt_locations = \
            gt_bboxes_3d[0].tensor.new_zeros((bs, max_objs, 3))
        for i in range(bs):
            inds[i, :num_ctrs[i]] = 1
            batch_centers_2d[i, :num_ctrs[i]] = centers_2d[i]
            batch_labels_3d[i, :num_ctrs[i]] = gt_labels_3d[i]
            batch_gt_locations[i, :num_ctrs[i]] = \
                gt_bboxes_3d[i].tensor[:, :3]

        inds = inds.flatten()
        batch_centers_2d = batch_centers_2d.view(-1, 2) * width_ratio
        batch_gt_locations = batch_gt_locations.view(-1, 3)

        # filter out empty images, i.e. those without gt_bboxes_3d
        gt_bboxes_3d = [
            gt_bbox_3d for gt_bbox_3d in gt_bboxes_3d
            if gt_bbox_3d.tensor.shape[0] > 0
        ]

        gt_dimensions = torch.cat(
            [gt_bbox_3d.tensor[:, 3:6] for gt_bbox_3d in gt_bboxes_3d])
        gt_orientations = torch.cat([
            gt_bbox_3d.tensor[:, 6].unsqueeze(-1)
            for gt_bbox_3d in gt_bboxes_3d
        ])
        gt_corners = torch.cat(
            [gt_bbox_3d.corners for gt_bbox_3d in gt_bboxes_3d])

        target_labels = dict(
            gt_centers_2d=batch_centers_2d.long(),
            gt_labels_3d=batch_labels_3d,
            indices=inds,
            reg_indices=reg_inds,
            gt_locs=batch_gt_locations,
            gt_dims=gt_dimensions,
            gt_yaws=gt_orientations,
            gt_cors=gt_corners)

        return center_heatmap_target, avg_factor, target_labels
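
    # Illustrative sketch, not part of the original implementation: each
    # ground-truth center becomes a 2D Gaussian peak on its class heatmap,
    # with the radius chosen by ``gaussian_radius`` so that boxes overlapping
    # the ground truth by >= 0.7 IoU still cover the peak. A rough
    # stand-alone rendering of one peak (border clipping omitted;
    # ``_draw_peak`` is a hypothetical helper, not ``gen_gaussian_target``):
    #
    #   import torch
    #   def _draw_peak(heatmap, center_x, center_y, radius):
    #       sigma = (2 * radius + 1) / 6
    #       y, x = torch.meshgrid(
    #           torch.arange(-radius, radius + 1),
    #           torch.arange(-radius, radius + 1), indexing='ij')
    #       gaussian = torch.exp(-(x**2 + y**2) / (2 * sigma**2))
    #       ys = slice(center_y - radius, center_y + radius + 1)
    #       xs = slice(center_x - radius, center_x + radius + 1)
    #       heatmap[ys, xs] = torch.maximum(heatmap[ys, xs], gaussian)
    #       return heatmap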

    def loss(self,
             cls_scores,
             bbox_preds,
             batch_gt_instances_3d,
             batch_img_metas,
             batch_gt_instances_ignore=None):
        """Compute loss of the head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level,
                each is a 4D-tensor of shape (B, num_classes, H, W).
            bbox_preds (list[Tensor]): Box regression maps for each scale
                level, each is a 4D-tensor of shape (B, channel, H, W).
            batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
                gt_instance_3d. It usually includes ``bboxes``, ``labels``,
                ``bboxes_3d``, ``labels_3d``, ``depths``, ``centers_2d`` and
                attributes.
            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        assert len(cls_scores) == len(bbox_preds) == 1
        assert batch_gt_instances_ignore is None
        center_2d_heatmap = cls_scores[0]
        pred_reg = bbox_preds[0]

        center_2d_heatmap_target, avg_factor, target_labels = \
            self.get_targets(batch_gt_instances_3d,
                             center_2d_heatmap.shape,
                             batch_img_metas)

        pred_bboxes = self.get_predictions(
            labels_3d=target_labels['gt_labels_3d'],
            centers_2d=target_labels['gt_centers_2d'],
            gt_locations=target_labels['gt_locs'],
            gt_dimensions=target_labels['gt_dims'],
            gt_orientations=target_labels['gt_yaws'],
            indices=target_labels['indices'],
            batch_img_metas=batch_img_metas,
            pred_reg=pred_reg)

        loss_cls = self.loss_cls(
            center_2d_heatmap, center_2d_heatmap_target, avg_factor=avg_factor)

        reg_inds = target_labels['reg_indices']

        loss_bbox_oris = self.loss_bbox(
            pred_bboxes['ori'].corners[reg_inds, ...],
            target_labels['gt_cors'][reg_inds, ...])

        loss_bbox_dims = self.loss_bbox(
            pred_bboxes['dim'].corners[reg_inds, ...],
            target_labels['gt_cors'][reg_inds, ...])

        loss_bbox_locs = self.loss_bbox(
            pred_bboxes['loc'].corners[reg_inds, ...],
            target_labels['gt_cors'][reg_inds, ...])

        loss_bbox = loss_bbox_dims + loss_bbox_locs + loss_bbox_oris

        loss_dict = dict(loss_cls=loss_cls, loss_bbox=loss_bbox)

        return loss_dict
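
    # Illustrative end-to-end sketch, not part of the original implementation:
    # how the head is typically exercised during training. The variable names
    # are assumptions.
    #
    #   cls_scores, bbox_preds = head(feats)  # feats: tuple of feature maps
    #   losses = head.loss(cls_scores, bbox_preds,
    #                      batch_gt_instances_3d, batch_img_metas)
    #   total_loss = losses['loss_cls'] + losses['loss_bbox']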