fcos_mono3d_head.py 42.6 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
ZCMax's avatar
ZCMax committed
2
from typing import List, Optional, Sequence, Tuple
3

twang's avatar
twang committed
4
5
import numpy as np
import torch
Tai-Wang's avatar
Tai-Wang committed
6
from mmcv.cnn import Scale, normal_init
7
from mmengine.data import InstanceData
ZCMax's avatar
ZCMax committed
8
from torch import Tensor
twang's avatar
twang committed
9
10
from torch import nn as nn

zhangshilong's avatar
zhangshilong committed
11
from mmdet3d.models.layers import box3d_multiclass_nms
12
from mmdet3d.registry import MODELS, TASK_UTILS
zhangshilong's avatar
zhangshilong committed
13
14
15
16
from mmdet3d.structures import limit_period, points_img2cam, xywhr2xyxyr
from mmdet3d.utils import (ConfigType, InstanceList, OptConfigType,
                           OptInstanceList)
from mmdet.models.utils import multi_apply, select_single_mlvl
twang's avatar
twang committed
17
18
from .anchor_free_mono3d_head import AnchorFreeMono3DHead

ZCMax's avatar
ZCMax committed
19
20
# Type alias for ``regress_ranges``: one (lower, upper) pair per entry.
RangeType = Sequence[Tuple[int, int]]

twang's avatar
twang committed
21
22
23
# Large sentinel used as the open upper bound of the last default regress
# range, i.e. ``(384, INF)``.
INF = 1e8


24
@MODELS.register_module()
twang's avatar
twang committed
25
26
27
28
29
30
31
class FCOSMono3DHead(AnchorFreeMono3DHead):
    """Anchor-free head used in FCOS3D.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        regress_ranges (Sequence[Tuple[int, int]]): Regress range of multiple
            level points.
        center_sampling (bool): If true, use center sampling. Default: True.
        center_sample_radius (float): Radius of center sampling. Default: 1.5.
        norm_on_bbox (bool): If true, normalize the regression targets
            with FPN strides. Default: True.
        centerness_on_reg (bool): If true, position centerness on the
            regress branch. Please refer to
            https://github.com/tianzhi0549/FCOS/issues/89#issuecomment-516877042.
            Default: True.
        centerness_alpha (float): Parameter used to adjust the intensity
            attenuation from the center to the periphery. Default: 2.5.
        loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
        loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss.
        loss_dir (:obj:`ConfigDict` or dict): Config of direction classification loss.
        loss_attr (:obj:`ConfigDict` or dict): Config of attribute classification loss.
        loss_centerness (:obj:`ConfigDict` or dict): Config of centerness loss.
        bbox_coder (:obj:`ConfigDict` or dict): Config of the bbox coder.
            Default: dict(type='FCOS3DBBoxCoder', code_size=9).
        norm_cfg (:obj:`ConfigDict` or dict): Dictionary to construct and config norm layer.
            Default: norm_cfg=dict(type='GN', num_groups=32, requires_grad=True).
        centerness_branch (tuple[int]): Channels for centerness branch.
            Default: (64, ).
        init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
            dict]): Initialization config dict.
    """  # noqa: E501

    def __init__(self,
                 regress_ranges: RangeType = ((-1, 48), (48, 96), (96, 192),
                                              (192, 384), (384, INF)),
                 center_sampling: bool = True,
                 center_sample_radius: float = 1.5,
                 norm_on_bbox: bool = True,
                 centerness_on_reg: bool = True,
                 centerness_alpha: float = 2.5,
                 loss_cls: ConfigType = dict(
                     type='mmdet.FocalLoss',
                     use_sigmoid=True,
                     gamma=2.0,
                     alpha=0.25,
                     loss_weight=1.0),
                 loss_bbox: ConfigType = dict(
                     type='mmdet.SmoothL1Loss',
                     beta=1.0 / 9.0,
                     loss_weight=1.0),
                 loss_dir: ConfigType = dict(
                     type='mmdet.CrossEntropyLoss',
                     use_sigmoid=False,
                     loss_weight=1.0),
                 loss_attr: ConfigType = dict(
                     type='mmdet.CrossEntropyLoss',
                     use_sigmoid=False,
                     loss_weight=1.0),
                 loss_centerness: ConfigType = dict(
                     type='mmdet.CrossEntropyLoss',
                     use_sigmoid=True,
                     loss_weight=1.0),
                 bbox_coder: ConfigType = dict(
                     type='FCOS3DBBoxCoder', code_size=9),
                 norm_cfg: ConfigType = dict(
                     type='GN', num_groups=32, requires_grad=True),
                 centerness_branch: Tuple[int] = (64, ),
                 init_cfg: OptConfigType = None,
                 **kwargs) -> None:
        """Store the FCOS3D-specific options and build the head.

        See the class docstring for the meaning of each argument. Remaining
        keyword arguments are forwarded unchanged to the parent constructor.
        """
        # These attributes are assigned before ``super().__init__``;
        # ``_init_layers`` reads ``self.centerness_branch``, so presumably the
        # parent constructor triggers layer construction (TODO confirm).
        self.regress_ranges = regress_ranges
        self.center_sampling = center_sampling
        self.center_sample_radius = center_sample_radius
        self.norm_on_bbox = norm_on_bbox
        self.centerness_on_reg = centerness_on_reg
        self.centerness_alpha = centerness_alpha
        self.centerness_branch = centerness_branch
        super().__init__(
            loss_cls=loss_cls,
            loss_bbox=loss_bbox,
            loss_dir=loss_dir,
            loss_attr=loss_attr,
            norm_cfg=norm_cfg,
            init_cfg=init_cfg,
            **kwargs)
        self.loss_centerness = MODELS.build(loss_centerness)
        # Override ``code_size`` on a copy: writing into ``bbox_coder`` itself
        # would mutate the shared default-argument dict (and any config object
        # the caller passed in) as a hidden side effect.
        bbox_coder = dict(bbox_coder, code_size=self.bbox_code_size)
        self.bbox_coder = TASK_UTILS.build(bbox_coder)
twang's avatar
twang committed
112
113
114
115
116
117
118
119

    def _init_layers(self):
        """Create the parent layers, then the centerness branch and the
        learnable per-level scales."""
        super()._init_layers()

        branch_channels = self.centerness_branch
        unit_strides = (1, ) * len(branch_channels)
        self.conv_centerness_prev = self._init_branch(
            conv_channels=branch_channels, conv_strides=unit_strides)
        # Single-channel 1x1 conv producing the centerness map.
        self.conv_centerness = nn.Conv2d(branch_channels[-1], 1, 1)

        # only for offset, depth and size regression
        self.scale_dim = 3
        # One group of ``scale_dim`` Scale modules per entry in self.strides.
        per_level_scales = []
        for _ in self.strides:
            level_scales = [Scale(1.0) for _ in range(self.scale_dim)]
            per_level_scales.append(nn.ModuleList(level_scales))
        self.scales = nn.ModuleList(per_level_scales)
twang's avatar
twang committed
125

Tai-Wang's avatar
Tai-Wang committed
126
127
128
129
130
131
132
133
134
135
136
137
138
    def init_weights(self):
        """Initialize weights of the head.

        A customized ``init_weights`` is still used here because the default
        DCN initialization triggered by ``init_cfg`` would also initialize
        ``conv_offset.weight``, which mistakenly affects training stability.
        """
        super().init_weights()
        # Re-initialize the convs of the centerness branch.
        for block in self.conv_centerness_prev:
            conv = block.conv
            if not isinstance(conv, nn.Conv2d):
                continue
            normal_init(conv, std=0.01)
        normal_init(self.conv_centerness, std=0.01)

ZCMax's avatar
ZCMax committed
139
140
141
142
    def forward(
        self, x: Tuple[Tensor]
    ) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor],
               List[Tensor]]:
        """Forward features from the upstream network.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple:
                cls_scores (list[Tensor]): Box scores for each scale level,
                    each is a 4D-tensor, the channel number is
                    num_points * num_classes.
                bbox_preds (list[Tensor]): Box energies / deltas for each scale
                    level, each is a 4D-tensor, the channel number is
                    num_points * bbox_code_size.
                dir_cls_preds (list[Tensor]): Box scores for direction class
                    predictions on each scale level, each is a 4D-tensor,
                    the channel number is num_points * 2. (bin = 2).
                attr_preds (list[Tensor]): Attribute scores for each scale
                    level, each is a 4D-tensor, the channel number is
                    num_points * num_attrs.
                centernesses (list[Tensor]): Centerness for each scale level,
                    each is a 4D-tensor, the channel number is num_points * 1.
        """
        per_level_outputs = multi_apply(self.forward_single, x, self.scales,
                                        self.strides)
        # ``forward_single`` also returns the cls/reg features; keep only the
        # first five entries, i.e. the predictions.
        return per_level_outputs[:5]
twang's avatar
twang committed
169

ZCMax's avatar
ZCMax committed
170
171
    def forward_single(self, x: Tensor, scale: Scale,
                       stride: int) -> Tuple[Tensor, ...]:
        """Forward features of a single scale level.

        Args:
            x (Tensor): FPN feature maps of the specified stride.
            scale (:obj: `mmcv.cnn.Scale`): Learnable scale module to resize
                the bbox prediction. It is passed through to the bbox
                coder's ``decode``.
            stride (int): The corresponding stride for feature maps, only
                used to normalize the bbox prediction when self.norm_on_bbox
                is True.

        Returns:
            tuple: Class scores, decoded bbox predictions, direction class
                predictions, attribute predictions and centerness
                predictions of the input feature map, followed by the
                classification and regression features.
        """
        parent_outputs = super().forward_single(x)
        (cls_score, bbox_pred, dir_cls_pred, attr_pred, cls_feat,
         reg_feat) = parent_outputs

        # The centerness branch sits on either the regression or the
        # classification features; it works on a clone of the chosen tensor.
        branch_input = reg_feat if self.centerness_on_reg else cls_feat
        centerness_feat = branch_input.clone()
        for centerness_layer in self.conv_centerness_prev:
            centerness_feat = centerness_layer(centerness_feat)
        centerness = self.conv_centerness(centerness_feat)

        bbox_pred = self.bbox_coder.decode(bbox_pred, scale, stride,
                                           self.training, cls_score)

        return (cls_score, bbox_pred, dir_cls_pred, attr_pred, centerness,
                cls_feat, reg_feat)
twang's avatar
twang committed
205
206

    @staticmethod
ZCMax's avatar
ZCMax committed
207
208
    def add_sin_difference(boxes1: Tensor,
                           boxes2: Tensor) -> Tuple[Tensor, Tensor]:
twang's avatar
twang committed
209
210
211
212
213
214
215
216
217
        """Convert the rotation difference to difference in sine function.

        Args:
            boxes1 (torch.Tensor): Original Boxes in shape (NxC), where C>=7
                and the 7th dimension is rotation dimension.
            boxes2 (torch.Tensor): Target boxes in shape (NxC), where C>=7 and
                the 7th dimension is rotation dimension.

        Returns:
218
            tuple[torch.Tensor]: ``boxes1`` and ``boxes2`` whose 7th
twang's avatar
twang committed
219
220
221
222
223
224
225
226
227
228
229
230
231
                dimensions are changed.
        """
        rad_pred_encoding = torch.sin(boxes1[..., 6:7]) * torch.cos(
            boxes2[..., 6:7])
        rad_tg_encoding = torch.cos(boxes1[..., 6:7]) * torch.sin(boxes2[...,
                                                                         6:7])
        boxes1 = torch.cat(
            [boxes1[..., :6], rad_pred_encoding, boxes1[..., 7:]], dim=-1)
        boxes2 = torch.cat([boxes2[..., :6], rad_tg_encoding, boxes2[..., 7:]],
                           dim=-1)
        return boxes1, boxes2

    @staticmethod
    def get_direction_target(reg_targets: Tensor,
                             dir_offset: int = 0,
                             dir_limit_offset: float = 0.0,
                             num_bins: int = 2,
                             one_hot: bool = True) -> Tensor:
        """Encode direction to 0 ~ num_bins-1.

        Args:
            reg_targets (torch.Tensor): Bbox regression targets. The rotation
                is read from index 6 of the last dimension.
            dir_offset (int, optional): Direction offset. Default to 0.
            dir_limit_offset (float, optional): Offset to set the direction
                range. Default to 0.0.
            num_bins (int, optional): Number of bins to divide 2*PI.
                Default to 2.
            one_hot (bool, optional): Whether to encode as one hot.
                Default to True.

        Returns:
            torch.Tensor: Encoded direction targets. Integer bin indices with
            the shape of ``reg_targets[..., 6]`` when ``one_hot`` is False;
            otherwise a one-hot tensor with an extra trailing dimension of
            size ``num_bins`` and the dtype of ``reg_targets``.
        """
        rot_gt = reg_targets[..., 6]
        # presumably wraps the shifted angle into one period of length 2*pi
        # positioned by ``dir_limit_offset`` -- confirm against limit_period.
        offset_rot = limit_period(rot_gt - dir_offset, dir_limit_offset,
                                  2 * np.pi)
        bin_width = 2 * np.pi / num_bins
        dir_cls_targets = torch.floor(offset_rot / bin_width).long()
        # Guard against indices falling just outside [0, num_bins - 1].
        dir_cls_targets = torch.clamp(dir_cls_targets, min=0, max=num_bins - 1)
        if one_hot:
            dir_targets = torch.zeros(
                *list(dir_cls_targets.shape),
                num_bins,
                dtype=reg_targets.dtype,
                device=dir_cls_targets.device)
            # ``scatter_`` needs the dimension as its first argument; the
            # previous call omitted it and raised a TypeError. Scatter along
            # the trailing (bin) dimension.
            dir_targets.scatter_(-1, dir_cls_targets.unsqueeze(dim=-1), 1.0)
            dir_cls_targets = dir_targets
        return dir_cls_targets

ZCMax's avatar
ZCMax committed
268
269
270
271
272
273
274
275
276
277
278
    def loss_by_feat(
            self,
            cls_scores: List[Tensor],
            bbox_preds: List[Tensor],
            dir_cls_preds: List[Tensor],
            attr_preds: List[Tensor],
            centernesses: List[Tensor],
            batch_gt_instances_3d: InstanceList,
            batch_gt_instacnes: InstanceList,
            batch_img_metas: List[dict],
            batch_gt_instances_ignore: OptInstanceList = None) -> dict:
        """Compute loss of the head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level,
                each is a 4D-tensor, the channel number is
                num_points * num_classes.
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level, each is a 4D-tensor, the channel number is
                num_points * bbox_code_size.
            dir_cls_preds (list[Tensor]): Box scores for direction class
                predictions on each scale level, each is a 4D-tensor,
                the channel number is num_points * 2. (bin = 2)
            attr_preds (list[Tensor]): Attribute scores for each scale level,
                each is a 4D-tensor, the channel number is
                num_points * num_attrs.
            centernesses (list[Tensor]): Centerness for each scale level, each
                is a 4D-tensor, the channel number is num_points * 1.
            batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
                gt_instance_3d. It usually includes ``bboxes_3d``,
                ``labels_3d``, ``depths``, ``centers_2d`` and attributes.
            batch_gt_instacnes (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``.
                (The parameter name is misspelled; it is kept as-is so
                keyword callers keep working.)
            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc. Not used by this method.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None. Not used by this method.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        assert len(cls_scores) == len(bbox_preds) == len(centernesses) == len(
            attr_preds)
        featmap_sizes = [level_map.size()[-2:] for level_map in cls_scores]
        reference_pred = bbox_preds[0]
        all_level_points = self.get_points(featmap_sizes, reference_pred.dtype,
                                           reference_pred.device)
        targets = self.get_targets(all_level_points, batch_gt_instances_3d,
                                   batch_gt_instacnes)
        labels_3d, bbox_targets_3d, centerness_targets, attr_targets = targets

        num_imgs = cls_scores[0].size(0)
        code_dims = sum(self.group_reg_dims)

        # Flatten every level to (num_points_total, channels) and concatenate.
        flatten_cls_scores = torch.cat([
            level.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
            for level in cls_scores
        ])
        flatten_bbox_preds = torch.cat([
            level.permute(0, 2, 3, 1).reshape(-1, code_dims)
            for level in bbox_preds
        ])
        flatten_dir_cls_preds = torch.cat([
            level.permute(0, 2, 3, 1).reshape(-1, 2)
            for level in dir_cls_preds
        ])
        flatten_centerness = torch.cat([
            level.permute(0, 2, 3, 1).reshape(-1) for level in centernesses
        ])
        flatten_labels_3d = torch.cat(labels_3d)
        flatten_bbox_targets_3d = torch.cat(bbox_targets_3d)
        flatten_centerness_targets = torch.cat(centerness_targets)

        # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
        bg_class_ind = self.num_classes
        is_foreground = (flatten_labels_3d >= 0) & (
            flatten_labels_3d < bg_class_ind)
        pos_inds = is_foreground.nonzero().reshape(-1)
        num_pos = len(pos_inds)

        # Adding num_imgs keeps the normalizer non-zero when num_pos is 0.
        loss_cls = self.loss_cls(
            flatten_cls_scores,
            flatten_labels_3d,
            avg_factor=num_pos + num_imgs)

        pos_bbox_preds = flatten_bbox_preds[pos_inds]
        pos_dir_cls_preds = flatten_dir_cls_preds[pos_inds]
        pos_centerness = flatten_centerness[pos_inds]

        if self.pred_attrs:
            flatten_attr_preds = torch.cat([
                level.permute(0, 2, 3, 1).reshape(-1, self.num_attrs)
                for level in attr_preds
            ])
            flatten_attr_targets = torch.cat(attr_targets)
            pos_attr_preds = flatten_attr_preds[pos_inds]

        # Optional components stay None unless the matching feature is on.
        loss_velo = None
        loss_dir = None
        loss_attr = None

        if num_pos > 0:
            pos_bbox_targets_3d = flatten_bbox_targets_3d[pos_inds]
            pos_centerness_targets = flatten_centerness_targets[pos_inds]
            if self.pred_attrs:
                pos_attr_targets = flatten_attr_targets[pos_inds]
            bbox_weights = pos_centerness_targets.new_ones(
                len(pos_centerness_targets), code_dims)
            equal_weights = pos_centerness_targets.new_ones(
                pos_centerness_targets.shape)
            # Shared normalizer of the regression and direction losses.
            avg_factor = equal_weights.sum()

            code_weight = self.train_cfg.get('code_weight', None)
            if code_weight:
                assert len(code_weight) == code_dims
                bbox_weights = bbox_weights * bbox_weights.new_tensor(
                    code_weight)

            # Direction targets must be taken from the raw rotation targets,
            # i.e. before the sine encoding below rewrites index 6.
            if self.use_direction_classifier:
                pos_dir_cls_targets = self.get_direction_target(
                    pos_bbox_targets_3d,
                    self.dir_offset,
                    self.dir_limit_offset,
                    one_hot=False)

            if self.diff_rad_by_sin:
                pos_bbox_preds, pos_bbox_targets_3d = self.add_sin_difference(
                    pos_bbox_preds, pos_bbox_targets_3d)

            loss_offset = self.loss_bbox(
                pos_bbox_preds[:, :2],
                pos_bbox_targets_3d[:, :2],
                weight=bbox_weights[:, :2],
                avg_factor=avg_factor)
            loss_depth = self.loss_bbox(
                pos_bbox_preds[:, 2],
                pos_bbox_targets_3d[:, 2],
                weight=bbox_weights[:, 2],
                avg_factor=avg_factor)
            loss_size = self.loss_bbox(
                pos_bbox_preds[:, 3:6],
                pos_bbox_targets_3d[:, 3:6],
                weight=bbox_weights[:, 3:6],
                avg_factor=avg_factor)
            loss_rotsin = self.loss_bbox(
                pos_bbox_preds[:, 6],
                pos_bbox_targets_3d[:, 6],
                weight=bbox_weights[:, 6],
                avg_factor=avg_factor)
            if self.pred_velo:
                loss_velo = self.loss_bbox(
                    pos_bbox_preds[:, 7:9],
                    pos_bbox_targets_3d[:, 7:9],
                    weight=bbox_weights[:, 7:9],
                    avg_factor=avg_factor)

            loss_centerness = self.loss_centerness(pos_centerness,
                                                   pos_centerness_targets)

            # direction classification loss
            # TODO: add more check for use_direction_classifier
            if self.use_direction_classifier:
                loss_dir = self.loss_dir(
                    pos_dir_cls_preds,
                    pos_dir_cls_targets,
                    equal_weights,
                    avg_factor=avg_factor)

            # attribute classification loss, weighted by centerness targets
            if self.pred_attrs:
                loss_attr = self.loss_attr(
                    pos_attr_preds,
                    pos_attr_targets,
                    pos_centerness_targets,
                    avg_factor=pos_centerness_targets.sum())
        else:
            # No positive samples: every ``pos_*`` tensor is an empty
            # selection, so each loss is the plain sum over that selection.
            loss_offset = pos_bbox_preds[:, :2].sum()
            loss_depth = pos_bbox_preds[:, 2].sum()
            loss_size = pos_bbox_preds[:, 3:6].sum()
            loss_rotsin = pos_bbox_preds[:, 6].sum()
            if self.pred_velo:
                loss_velo = pos_bbox_preds[:, 7:9].sum()
            loss_centerness = pos_centerness.sum()
            if self.use_direction_classifier:
                loss_dir = pos_dir_cls_preds.sum()
            if self.pred_attrs:
                loss_attr = pos_attr_preds.sum()

        loss_dict = dict(
            loss_cls=loss_cls,
            loss_offset=loss_offset,
            loss_depth=loss_depth,
            loss_size=loss_size,
            loss_rotsin=loss_rotsin,
            loss_centerness=loss_centerness)

        optional_losses = (('loss_velo', loss_velo), ('loss_dir', loss_dir),
                           ('loss_attr', loss_attr))
        for loss_name, loss_value in optional_losses:
            if loss_value is not None:
                loss_dict[loss_name] = loss_value

        return loss_dict

ZCMax's avatar
ZCMax committed
483
484
485
486
487
488
489
490
491
    def predict_by_feat(self,
                        cls_scores: List[Tensor],
                        bbox_preds: List[Tensor],
                        dir_cls_preds: List[Tensor],
                        attr_preds: List[Tensor],
                        centernesses: List[Tensor],
                        batch_img_metas: Optional[List[dict]] = None,
                        cfg: OptConfigType = None,
                        rescale: bool = False) -> InstanceList:
        """Transform network output for a batch into bbox predictions.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level
                Has shape (N, num_points * num_classes, H, W)
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level with shape (N, num_points * bbox_code_size, H, W)
            dir_cls_preds (list[Tensor]): Box scores for direction class
                predictions on each scale level, each is a 4D-tensor,
                the channel number is num_points * 2. (bin = 2)
            attr_preds (list[Tensor]): Attribute scores for each scale level
                Has shape (N, num_points * num_attrs, H, W)
            centernesses (list[Tensor]): Centerness for each scale level with
                shape (N, num_points * 1, H, W)
            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc. One result is produced per
                entry, so it must be provided despite the ``None`` default.
            cfg (ConfigDict, optional): Test / postprocessing
                configuration, if None, test_cfg would be used.
                Defaults to None.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.

        Returns:
            list[:obj:`InstanceData`]: Object detection results of each image
            after the post process. Each item usually contains following keys.

                - scores_3d (Tensor): Classification scores, has a shape
                  (num_instance, )
                - labels_3d (Tensor): Labels of bboxes, has a shape
                  (num_instances, ).
                - bboxes_3d (Tensor): Contains a tensor with shape
                  (num_instances, C), where C >= 7.
        """
        assert len(cls_scores) == len(bbox_preds) == len(dir_cls_preds) == \
            len(centernesses) == len(attr_preds)

        featmap_sizes = [level_map.size()[-2:] for level_map in cls_scores]
        # TODO: refactor using prior_generator
        mlvl_points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
                                      bbox_preds[0].device)

        def _constant_maps(img_id, num_channels, fill_value):
            """Build one detached constant map per level for image
            ``img_id``, matching the spatial size of its class scores."""
            return [
                level_scores[img_id].new_full(
                    [num_channels, *level_scores[img_id].shape[1:]],
                    fill_value).detach() for level_scores in cls_scores
            ]

        result_list = []
        for img_id, img_meta in enumerate(batch_img_metas):
            cls_score_list = select_single_mlvl(cls_scores, img_id)
            bbox_pred_list = select_single_mlvl(bbox_preds, img_id)

            # Heads without a direction classifier get all-zero placeholders.
            if self.use_direction_classifier:
                dir_cls_pred_list = select_single_mlvl(dir_cls_preds, img_id)
            else:
                dir_cls_pred_list = _constant_maps(img_id, 2, 0)

            # Heads without attribute prediction get placeholders filled
            # with the attribute background label.
            if self.pred_attrs:
                attr_pred_list = select_single_mlvl(attr_preds, img_id)
            else:
                attr_pred_list = _constant_maps(img_id, self.num_attrs,
                                                self.attr_background_label)

            centerness_pred_list = select_single_mlvl(centernesses, img_id)
            single_result = self._predict_by_feat_single(
                cls_score_list=cls_score_list,
                bbox_pred_list=bbox_pred_list,
                dir_cls_pred_list=dir_cls_pred_list,
                attr_pred_list=attr_pred_list,
                centerness_pred_list=centerness_pred_list,
                mlvl_points=mlvl_points,
                img_meta=img_meta,
                cfg=cfg,
                rescale=rescale)
            result_list.append(single_result)
        return result_list

ZCMax's avatar
ZCMax committed
572
573
574
575
576
577
578
579
580
581
    def _predict_by_feat_single(self,
                                cls_score_list: List[Tensor],
                                bbox_pred_list: List[Tensor],
                                dir_cls_pred_list: List[Tensor],
                                attr_pred_list: List[Tensor],
                                centerness_pred_list: List[Tensor],
                                mlvl_points: List[Tensor],
                                img_meta: dict,
                                cfg: ConfigType,
                                rescale: bool = False) -> InstanceData:
        """Transform outputs for a single batch item into bbox predictions.

        Args:
            cls_score_list (list[Tensor]): Box scores of each scale level,
                each with shape (num_points * num_classes, H, W).
            bbox_pred_list (list[Tensor]): Box energies / deltas of each
                scale level, each with shape
                (num_points * bbox_code_size, H, W).
            dir_cls_pred_list (list[Tensor]): Box scores for direction class
                predictions of each scale level, each with shape
                (num_points * 2, H, W).
            attr_pred_list (list[Tensor]): Attribute scores of each scale
                level, each with shape (num_points * num_attrs, H, W).
            centerness_pred_list (list[Tensor]): Centerness of each scale
                level, each with shape (num_points, H, W).
            mlvl_points (list[Tensor]): Point coordinates of each scale
                level, each with shape (num_points, 2).
            img_meta (dict): Metadata of input image. The keys ``cam2img``,
                ``scale_factor`` and ``box_type_3d`` are read here.
            cfg (mmcv.Config): Test / postprocessing configuration,
                if None, test_cfg would be used.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.

        Returns:
            :obj:`InstanceData`: 3D Detection results of each image
            after the post process.
            Each item usually contains following keys.

                - scores_3d (Tensor): Classification scores, has a shape
                  (num_instance, )
                - labels_3d (Tensor): Labels of bboxes, has a shape
                  (num_instances, ).
                - bboxes_3d (Tensor): Contains a tensor with shape
                  (num_instances, C), where C >= 7.
                - attr_labels (Tensor, optional): Attribute labels of bboxes,
                  set to None when ``self.pred_attrs`` is False.
        """
        view = np.array(img_meta['cam2img'])
        scale_factor = img_meta['scale_factor']
        cfg = self.test_cfg if cfg is None else cfg
        assert len(cls_score_list) == len(bbox_pred_list) == len(mlvl_points)
        # the pre-NMS budget does not depend on the level, look it up once
        nms_pre = cfg.get('nms_pre', -1)

        mlvl_centers_2d = []
        mlvl_bboxes = []
        mlvl_scores = []
        mlvl_dir_scores = []
        mlvl_attr_scores = []
        mlvl_centerness = []

        for cls_score, bbox_pred, dir_cls_pred, attr_pred, centerness, \
                points in zip(cls_score_list, bbox_pred_list,
                              dir_cls_pred_list, attr_pred_list,
                              centerness_pred_list, mlvl_points):
            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
            # (C, H, W) -> (H * W, C) for every prediction branch
            scores = cls_score.permute(1, 2, 0).reshape(
                -1, self.cls_out_channels).sigmoid()
            dir_cls_pred = dir_cls_pred.permute(1, 2, 0).reshape(-1, 2)
            # index of the most likely direction bin / attribute
            dir_cls_score = torch.max(dir_cls_pred, dim=-1)[1]
            attr_pred = attr_pred.permute(1, 2, 0).reshape(-1, self.num_attrs)
            attr_score = torch.max(attr_pred, dim=-1)[1]
            centerness = centerness.permute(1, 2, 0).reshape(-1).sigmoid()

            bbox_pred = bbox_pred.permute(1, 2,
                                          0).reshape(-1,
                                                     sum(self.group_reg_dims))
            bbox_pred = bbox_pred[:, :self.bbox_code_size]
            if nms_pre > 0 and scores.shape[0] > nms_pre:
                # keep the nms_pre locations with the best
                # centerness-weighted class score
                max_scores, _ = (scores * centerness[:, None]).max(dim=1)
                _, topk_inds = max_scores.topk(nms_pre)
                points = points[topk_inds, :]
                bbox_pred = bbox_pred[topk_inds, :]
                scores = scores[topk_inds, :]
                centerness = centerness[topk_inds]
                dir_cls_score = dir_cls_score[topk_inds]
                attr_score = attr_score[topk_inds]
            # change the offset to actual center predictions
            bbox_pred[:, :2] = points - bbox_pred[:, :2]
            if rescale:
                bbox_pred[:, :2] /= bbox_pred[:, :2].new_tensor(scale_factor)
            # keep the image-plane center (+ depth) before it is overwritten
            # by the camera-frame center; decode_yaw needs it later
            pred_center2d = bbox_pred[:, :3].clone()
            bbox_pred[:, :3] = points_img2cam(bbox_pred[:, :3], view)
            mlvl_centers_2d.append(pred_center2d)
            mlvl_bboxes.append(bbox_pred)
            mlvl_scores.append(scores)
            mlvl_dir_scores.append(dir_cls_score)
            mlvl_attr_scores.append(attr_score)
            mlvl_centerness.append(centerness)

        mlvl_centers_2d = torch.cat(mlvl_centers_2d)
        mlvl_bboxes = torch.cat(mlvl_bboxes)
        mlvl_dir_scores = torch.cat(mlvl_dir_scores)

        # change local yaw to global yaw for 3D nms
        # (the intrinsics are zero-padded into a 4x4 matrix first)
        cam2img = mlvl_centers_2d.new_zeros((4, 4))
        cam2img[:view.shape[0], :view.shape[1]] = \
            mlvl_centers_2d.new_tensor(view)
        mlvl_bboxes = self.bbox_coder.decode_yaw(mlvl_bboxes, mlvl_centers_2d,
                                                 mlvl_dir_scores,
                                                 self.dir_offset, cam2img)

        mlvl_bboxes_for_nms = xywhr2xyxyr(img_meta['box_type_3d'](
            mlvl_bboxes, box_dim=self.bbox_code_size,
            origin=(0.5, 0.5, 0.5)).bev)

        mlvl_scores = torch.cat(mlvl_scores)
        padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
        # remind that we set FG labels to [0, num_class-1] since mmdet v2.0
        # BG cat_id: num_class
        mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)
        mlvl_attr_scores = torch.cat(mlvl_attr_scores)
        mlvl_centerness = torch.cat(mlvl_centerness)
        # no scale_factors in box3d_multiclass_nms
        # Then we multiply it from outside
        mlvl_nms_scores = mlvl_scores * mlvl_centerness[:, None]
        results = box3d_multiclass_nms(mlvl_bboxes, mlvl_bboxes_for_nms,
                                       mlvl_nms_scores, cfg.score_thr,
                                       cfg.max_per_img, cfg, mlvl_dir_scores,
                                       mlvl_attr_scores)
        bboxes, scores, labels, dir_scores, attrs = results
        attrs = attrs.to(labels.dtype)  # change data type to int
        bboxes = img_meta['box_type_3d'](
            bboxes, box_dim=self.bbox_code_size, origin=(0.5, 0.5, 0.5))
        # Note that the predictions use origin (0.5, 0.5, 0.5)
        # Due to the ground truth centers_2d are the gravity center of objects
        # v0.10.0 fix inplace operation to the input tensor of cam_box3d
        # So here we also need to add origin=(0.5, 0.5, 0.5)
        if not self.pred_attrs:
            attrs = None

        results = InstanceData()
        results.bboxes_3d = bboxes
        results.scores_3d = scores
        results.labels_3d = labels
        results.attr_labels = attrs

        return results
twang's avatar
twang committed
715

ZCMax's avatar
ZCMax committed
716
717
718
719
720
721
722
723
    def _get_points_single(self,
                           featmap_size: Tuple[int],
                           stride: int,
                           dtype: torch.dtype,
                           device: torch.device,
                           flatten: bool = False) -> Tensor:
        """Get points of a single scale level.

        Args:
            featmap_size (tuple[int]): Single scale level feature map size.
            stride (int): Downsample factor of the feature map.
            dtype (torch.dtype): Type of points.
            device (torch.device): Device of points.
            flatten (bool): Kept for interface compatibility. It is not
                forwarded to the parent implementation and the returned
                points are always flattened. Defaults to False.

        Returns:
            Tensor: points of each image, stacked as (x, y) pairs along the
            last dimension.
        """
        # NOTE(review): the parent is expected to return the y / x grid
        # indices of the feature map - confirm against AnchorFreeMono3DHead.
        y, x = super()._get_points_single(featmap_size, stride, dtype, device)
        # scale grid indices by the stride and shift by half a stride
        points = torch.stack((x.reshape(-1) * stride, y.reshape(-1) * stride),
                             dim=-1) + stride // 2
        return points

ZCMax's avatar
ZCMax committed
740
741
742
743
744
745
    def get_targets(
        self,
        points: List[Tensor],
        batch_gt_instances_3d: InstanceList,
        batch_gt_instances: InstanceList,
    ) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]:
        """Compute regression, classification and centerness targets for
        points in multiple images.

        Args:
            points (list[Tensor]): Points of each fpn level, each has shape
                (num_points, 2).
            batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
                gt_instance_3d.  It usually includes ``bboxes_3d``,
                ``labels_3d``, ``depths``, ``centers_2d`` and attributes.
                Samples without ``attr_labels`` get one filled with
                ``self.attr_background_label`` (written onto the sample).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance.  It usually includes ``bboxes``, ``labels``.

        Returns:
            tuple: Targets of each level.

            - concat_lvl_labels_3d (list[Tensor]): 3D Labels of each level.
            - concat_lvl_bbox_targets_3d (list[Tensor]): 3D BBox targets of
                each level.
            - concat_lvl_centerness_targets (list[Tensor]): Centerness targets
                of each level.
            - concat_lvl_attr_targets (list[Tensor]): Attribute targets of
                each level.
        """
        assert len(points) == len(self.regress_ranges)
        num_levels = len(points)
        # expand regress ranges to align with points
        expanded_regress_ranges = [
            points[i].new_tensor(self.regress_ranges[i])[None].expand_as(
                points[i]) for i in range(num_levels)
        ]
        # concat all levels points and regress ranges
        concat_regress_ranges = torch.cat(expanded_regress_ranges, dim=0)
        concat_points = torch.cat(points, dim=0)

        # the number of points per img, per lvl
        num_points = [center.size(0) for center in points]

        # Fill missing attribute labels with the background attribute.
        # The check is done per sample so that a batch mixing samples with
        # and without ``attr_labels`` neither fails downstream nor has its
        # existing attribute labels overwritten.
        for gt_instances_3d in batch_gt_instances_3d:
            if 'attr_labels' not in gt_instances_3d:
                gt_instances_3d.attr_labels = \
                    gt_instances_3d.labels_3d.new_full(
                        gt_instances_3d.labels_3d.shape,
                        self.attr_background_label
                    )

        # get labels and bbox_targets of each image
        # (the 2D labels / 2D bbox targets are not needed here)
        _, _, labels_3d_list, bbox_targets_3d_list, centerness_targets_list, \
            attr_targets_list = multi_apply(
                self._get_target_single,
                batch_gt_instances_3d,
                batch_gt_instances,
                points=concat_points,
                regress_ranges=concat_regress_ranges,
                num_points_per_lvl=num_points)

        # split to per img, per level
        labels_3d_list = [
            labels_3d.split(num_points, 0) for labels_3d in labels_3d_list
        ]
        bbox_targets_3d_list = [
            bbox_targets_3d.split(num_points, 0)
            for bbox_targets_3d in bbox_targets_3d_list
        ]
        centerness_targets_list = [
            centerness_targets.split(num_points, 0)
            for centerness_targets in centerness_targets_list
        ]
        attr_targets_list = [
            attr_targets.split(num_points, 0)
            for attr_targets in attr_targets_list
        ]

        # concat per level image
        concat_lvl_labels_3d = []
        concat_lvl_bbox_targets_3d = []
        concat_lvl_centerness_targets = []
        concat_lvl_attr_targets = []
        for i in range(num_levels):
            concat_lvl_labels_3d.append(
                torch.cat([labels[i] for labels in labels_3d_list]))
            concat_lvl_centerness_targets.append(
                torch.cat([
                    centerness_targets[i]
                    for centerness_targets in centerness_targets_list
                ]))
            bbox_targets_3d = torch.cat([
                bbox_targets_3d[i] for bbox_targets_3d in bbox_targets_3d_list
            ])
            concat_lvl_attr_targets.append(
                torch.cat(
                    [attr_targets[i] for attr_targets in attr_targets_list]))
            if self.norm_on_bbox:
                # express the 2D center offsets in units of the level stride
                bbox_targets_3d[:, :2] = \
                    bbox_targets_3d[:, :2] / self.strides[i]
            concat_lvl_bbox_targets_3d.append(bbox_targets_3d)
        return concat_lvl_labels_3d, concat_lvl_bbox_targets_3d, \
            concat_lvl_centerness_targets, concat_lvl_attr_targets

ZCMax's avatar
ZCMax committed
844
845
846
847
    def _get_target_single(
            self, gt_instances_3d: InstanceData, gt_instances: InstanceData,
            points: Tensor, regress_ranges: Tensor,
            num_points_per_lvl: List[int]) -> Tuple[Tensor, ...]:
        """Compute regression and classification targets for a single image.

        Args:
            gt_instances_3d (:obj:`InstanceData`): 3D ground truth of the
                image. ``bboxes_3d``, ``labels_3d``, ``centers_2d``,
                ``depths`` and ``attr_labels`` are read.
            gt_instances (:obj:`InstanceData`): 2D ground truth of the
                image. ``bboxes`` and ``labels`` are read.
            points (Tensor): Points of all fpn levels concatenated, has
                shape (num_points, 2).
            regress_ranges (Tensor): Regress range of every point, has
                shape (num_points, 2).
            num_points_per_lvl (list[int]): Number of points of each level.

        Returns:
            tuple[Tensor, ...]: Per-point targets, in this order:

            - labels (Tensor): 2D labels, has shape (num_points, ).
            - bbox_targets (Tensor): Distances to the left / top / right /
              bottom side of the assigned 2D box, has shape (num_points, 4).
            - labels_3d (Tensor): 3D labels, has shape (num_points, ).
            - bbox_targets_3d (Tensor): 3D box targets (2D center offsets,
              depth, then the remaining box code), has shape
              (num_points, bbox_code_size).
            - centerness_targets (Tensor): Centerness targets, has shape
              (num_points, ).
            - attr_labels (Tensor): Attribute labels, has shape
              (num_points, ).
        """
        num_points = points.size(0)
        num_gts = len(gt_instances_3d)
        gt_bboxes = gt_instances.bboxes
        gt_labels = gt_instances.labels
        gt_bboxes_3d = gt_instances_3d.bboxes_3d
        gt_labels_3d = gt_instances_3d.labels_3d
        centers_2d = gt_instances_3d.centers_2d
        depths = gt_instances_3d.depths
        attr_labels = gt_instances_3d.attr_labels

        if not isinstance(gt_bboxes_3d, torch.Tensor):
            gt_bboxes_3d = gt_bboxes_3d.tensor.to(gt_bboxes.device)
        if num_gts == 0:
            # no ground truth: every point is background
            return gt_labels.new_full((num_points,), self.background_label), \
                   gt_bboxes.new_zeros((num_points, 4)), \
                   gt_labels_3d.new_full(
                       (num_points,), self.background_label), \
                   gt_bboxes_3d.new_zeros((num_points, self.bbox_code_size)), \
                   gt_bboxes_3d.new_zeros((num_points,)), \
                   attr_labels.new_full(
                       (num_points,), self.attr_background_label)

        # `.to()` may hand back the very tensor stored in the ground truth
        # (when it already lives on the right device), so work on a copy to
        # avoid rewriting the caller's boxes in place below.
        gt_bboxes_3d = gt_bboxes_3d.clone()

        # change orientation to local yaw
        gt_bboxes_3d[..., 6] = -torch.atan2(
            gt_bboxes_3d[..., 0], gt_bboxes_3d[..., 2]) + gt_bboxes_3d[..., 6]

        # broadcast everything to (num_points, num_gts, ...)
        regress_ranges = regress_ranges[:, None, :].expand(
            num_points, num_gts, 2)
        gt_bboxes = gt_bboxes[None].expand(num_points, num_gts, 4)
        centers_2d = centers_2d[None].expand(num_points, num_gts, 2)
        gt_bboxes_3d = gt_bboxes_3d[None].expand(num_points, num_gts,
                                                 self.bbox_code_size)
        depths = depths[None, :, None].expand(num_points, num_gts, 1)
        xs, ys = points[:, 0], points[:, 1]
        xs = xs[:, None].expand(num_points, num_gts)
        ys = ys[:, None].expand(num_points, num_gts)

        # 3D targets: offset to the projected center, depth, rest of the code
        delta_xs = (xs - centers_2d[..., 0])[..., None]
        delta_ys = (ys - centers_2d[..., 1])[..., None]
        bbox_targets_3d = torch.cat(
            (delta_xs, delta_ys, depths, gt_bboxes_3d[..., 3:]), dim=-1)

        # 2D targets: distances from each point to the four box sides
        left = xs - gt_bboxes[..., 0]
        right = gt_bboxes[..., 2] - xs
        top = ys - gt_bboxes[..., 1]
        bottom = gt_bboxes[..., 3] - ys
        bbox_targets = torch.stack((left, top, right, bottom), -1)

        assert self.center_sampling is True, 'Setting center_sampling to '\
            'False has not been implemented for FCOS3D.'
        # condition1: inside a `center bbox`
        radius = self.center_sample_radius
        center_xs = centers_2d[..., 0]
        center_ys = centers_2d[..., 1]
        center_gts = torch.zeros_like(gt_bboxes)
        stride = center_xs.new_zeros(center_xs.shape)

        # project the points on current lvl back to the `original` sizes
        lvl_begin = 0
        for lvl_idx, num_points_lvl in enumerate(num_points_per_lvl):
            lvl_end = lvl_begin + num_points_lvl
            stride[lvl_begin:lvl_end] = self.strides[lvl_idx] * radius
            lvl_begin = lvl_end

        center_gts[..., 0] = center_xs - stride
        center_gts[..., 1] = center_ys - stride
        center_gts[..., 2] = center_xs + stride
        center_gts[..., 3] = center_ys + stride

        cb_dist_left = xs - center_gts[..., 0]
        cb_dist_right = center_gts[..., 2] - xs
        cb_dist_top = ys - center_gts[..., 1]
        cb_dist_bottom = center_gts[..., 3] - ys
        center_bbox = torch.stack(
            (cb_dist_left, cb_dist_top, cb_dist_right, cb_dist_bottom), -1)
        inside_gt_bbox_mask = center_bbox.min(-1)[0] > 0

        # condition2: limit the regression range for each location
        max_regress_distance = bbox_targets.max(-1)[0]
        inside_regress_range = (
            (max_regress_distance >= regress_ranges[..., 0])
            & (max_regress_distance <= regress_ranges[..., 1]))

        # center-based criterion to deal with ambiguity
        dists = torch.sqrt(torch.sum(bbox_targets_3d[..., :2]**2, dim=-1))
        dists[inside_gt_bbox_mask == 0] = INF
        dists[inside_regress_range == 0] = INF
        min_dist, min_dist_inds = dists.min(dim=1)

        labels = gt_labels[min_dist_inds]
        labels_3d = gt_labels_3d[min_dist_inds]
        attr_labels = attr_labels[min_dist_inds]
        labels[min_dist == INF] = self.background_label  # set as BG
        labels_3d[min_dist == INF] = self.background_label  # set as BG
        attr_labels[min_dist == INF] = self.attr_background_label

        # pick, for every point, the targets of its assigned ground truth
        bbox_targets = bbox_targets[range(num_points), min_dist_inds]
        bbox_targets_3d = bbox_targets_3d[range(num_points), min_dist_inds]
        # `stride` holds the same value for every gt of a point, so the
        # first column is enough; both operands have shape (num_points, )
        relative_dists = torch.sqrt(
            torch.sum(bbox_targets_3d[..., :2]**2,
                      dim=-1)) / (1.414 * stride[:, 0])
        centerness_targets = torch.exp(-self.centerness_alpha * relative_dists)

        return labels, bbox_targets, labels_3d, bbox_targets_3d, \
            centerness_targets, attr_labels