fcos_mono3d_head.py 42.7 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
ZCMax's avatar
ZCMax committed
2
from typing import List, Optional, Sequence, Tuple
3

twang's avatar
twang committed
4
5
import numpy as np
import torch
6
from mmcv.cnn import Scale
7
from mmdet.models.utils import multi_apply, select_single_mlvl
8
9
from mmengine.model import normal_init
from mmengine.structures import InstanceData
ZCMax's avatar
ZCMax committed
10
from torch import Tensor
twang's avatar
twang committed
11
12
from torch import nn as nn

zhangshilong's avatar
zhangshilong committed
13
from mmdet3d.models.layers import box3d_multiclass_nms
14
from mmdet3d.registry import MODELS, TASK_UTILS
zhangshilong's avatar
zhangshilong committed
15
16
17
from mmdet3d.structures import limit_period, points_img2cam, xywhr2xyxyr
from mmdet3d.utils import (ConfigType, InstanceList, OptConfigType,
                           OptInstanceList)
twang's avatar
twang committed
18
19
from .anchor_free_mono3d_head import AnchorFreeMono3DHead

ZCMax's avatar
ZCMax committed
20
21
# Type of ``regress_ranges``: one (lower, upper) regress range per feature
# level.
RangeType = Sequence[Tuple[int, int]]

twang's avatar
twang committed
22
23
24
# Large finite value used as the open upper bound of the last default
# regress range, i.e. ``(384, INF)``.
INF = 1e8


25
@MODELS.register_module()
twang's avatar
twang committed
26
27
28
29
30
31
32
class FCOSMono3DHead(AnchorFreeMono3DHead):
    """Anchor-free head used in FCOS3D.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
ZCMax's avatar
ZCMax committed
33
        regress_ranges (Sequence[Tuple[int, int]]): Regress range of multiple
twang's avatar
twang committed
34
            level points.
ZCMax's avatar
ZCMax committed
35
36
37
        center_sampling (bool): If true, use center sampling. Default: True.
        center_sample_radius (float): Radius of center sampling. Default: 1.5.
        norm_on_bbox (bool): If true, normalize the regression targets
twang's avatar
twang committed
38
            with FPN strides. Default: True.
ZCMax's avatar
ZCMax committed
39
40
41
        centerness_on_reg (bool): If true, position centerness on the
            regress branch. Please refer to
            https://github.com/tianzhi0549/FCOS/issues/89#issuecomment-516877042.
twang's avatar
twang committed
42
            Default: True.
ZCMax's avatar
ZCMax committed
43
        centerness_alpha (float): Parameter used to adjust the intensity
44
            attenuation from the center to the periphery. Default: 2.5.
ZCMax's avatar
ZCMax committed
45
46
47
48
49
50
        loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
        loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss.
        loss_dir (:obj:`ConfigDict` or dict): Config of direction classification loss.
        loss_attr (:obj:`ConfigDict` or dict): Config of attribute classification loss.
        loss_centerness (:obj:`ConfigDict` or dict): Config of centerness loss.
        bbox_coder (:obj:`ConfigDict` or dict): Config of the bbox coder. Its
            ``code_size`` is overridden with ``self.bbox_code_size``.
        norm_cfg (:obj:`ConfigDict` or dict): Dictionary to construct and
            config the norm layer.
twang's avatar
twang committed
51
            Default: norm_cfg=dict(type='GN', num_groups=32, requires_grad=True).
ZCMax's avatar
ZCMax committed
52
        centerness_branch (tuple[int]): Channels for centerness branch.
twang's avatar
twang committed
53
            Default: (64, ).
ZCMax's avatar
ZCMax committed
54
55
        init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
            dict]): Initialization config dict.
twang's avatar
twang committed
56
57
58
    """  # noqa: E501

    def __init__(self,
                 regress_ranges: RangeType = ((-1, 48), (48, 96), (96, 192),
                                              (192, 384), (384, INF)),
                 center_sampling: bool = True,
                 center_sample_radius: float = 1.5,
                 norm_on_bbox: bool = True,
                 centerness_on_reg: bool = True,
                 centerness_alpha: float = 2.5,
                 loss_cls: ConfigType = dict(
                     type='mmdet.FocalLoss',
                     use_sigmoid=True,
                     gamma=2.0,
                     alpha=0.25,
                     loss_weight=1.0),
                 loss_bbox: ConfigType = dict(
                     type='mmdet.SmoothL1Loss',
                     beta=1.0 / 9.0,
                     loss_weight=1.0),
                 loss_dir: ConfigType = dict(
                     type='mmdet.CrossEntropyLoss',
                     use_sigmoid=False,
                     loss_weight=1.0),
                 loss_attr: ConfigType = dict(
                     type='mmdet.CrossEntropyLoss',
                     use_sigmoid=False,
                     loss_weight=1.0),
                 loss_centerness: ConfigType = dict(
                     type='mmdet.CrossEntropyLoss',
                     use_sigmoid=True,
                     loss_weight=1.0),
                 bbox_coder: ConfigType = dict(
                     type='FCOS3DBBoxCoder', code_size=9),
                 norm_cfg: ConfigType = dict(
                     type='GN', num_groups=32, requires_grad=True),
                 centerness_branch: Tuple[int, ...] = (64, ),
                 init_cfg: OptConfigType = None,
                 **kwargs) -> None:
        """Initialize the FCOS3D head.

        Remaining ``kwargs`` are forwarded to
        :class:`AnchorFreeMono3DHead`. ``bbox_coder`` is built from a copy
        of the given config whose ``code_size`` is overridden with
        ``self.bbox_code_size``; the config passed in (and the shared
        default dict) is left unmodified.
        """
        # These attributes are set before ``super().__init__`` because
        # ``_init_layers`` reads ``self.centerness_branch``; presumably the
        # base constructor calls ``_init_layers`` (not visible here).
        self.regress_ranges = regress_ranges
        self.center_sampling = center_sampling
        self.center_sample_radius = center_sample_radius
        self.norm_on_bbox = norm_on_bbox
        self.centerness_on_reg = centerness_on_reg
        self.centerness_alpha = centerness_alpha
        self.centerness_branch = centerness_branch
        super().__init__(
            loss_cls=loss_cls,
            loss_bbox=loss_bbox,
            loss_dir=loss_dir,
            loss_attr=loss_attr,
            norm_cfg=norm_cfg,
            init_cfg=init_cfg,
            **kwargs)
        self.loss_centerness = MODELS.build(loss_centerness)
        # Copy instead of assigning into ``bbox_coder`` in place: the default
        # is a single dict shared by every instance, and callers may reuse
        # their config object elsewhere.
        bbox_coder = dict(bbox_coder, code_size=self.bbox_code_size)
        self.bbox_coder = TASK_UTILS.build(bbox_coder)
twang's avatar
twang committed
113
114
115
116
117
118
119
120

    def _init_layers(self):
        """Initialize layers of the head.

        On top of the base layers this adds the centerness branch (a stack
        of stride-1 convs followed by a 1-channel 1x1 conv) and, for every
        feature level, ``self.scale_dim`` learnable :class:`Scale` modules.
        """
        super()._init_layers()
        centerness_channels = self.centerness_branch
        unit_strides = tuple(1 for _ in centerness_channels)
        self.conv_centerness_prev = self._init_branch(
            conv_channels=centerness_channels, conv_strides=unit_strides)
        self.conv_centerness = nn.Conv2d(centerness_channels[-1], 1, 1)
        # Scales only apply to offset, depth and size regression.
        self.scale_dim = 3
        per_level_scales = [
            nn.ModuleList(Scale(1.0) for _ in range(self.scale_dim))
            for _ in self.strides
        ]
        self.scales = nn.ModuleList(per_level_scales)
twang's avatar
twang committed
126

Tai-Wang's avatar
Tai-Wang committed
127
128
129
130
131
132
133
134
135
136
137
138
139
    def init_weights(self):
        """Initialize weights of the head.

        After the base initialization, the convs of the centerness branch
        are re-initialized with ``normal_init(std=0.01)``. A customized
        ``init_weights`` is still used because the default init of DCN
        triggered by ``init_cfg`` would also initialize
        ``conv_offset.weight``, which hurts training stability.
        """
        super().init_weights()
        prev_convs = (layer.conv for layer in self.conv_centerness_prev)
        for conv in prev_convs:
            if isinstance(conv, nn.Conv2d):
                normal_init(conv, std=0.01)
        normal_init(self.conv_centerness, std=0.01)

ZCMax's avatar
ZCMax committed
140
141
142
143
    def forward(
        self, x: Tuple[Tensor]
    ) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor],
               List[Tensor]]:
        """Run ``forward_single`` on every level of the upstream features.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple: Five per-level lists, in this order:
                cls_scores (list[Tensor]): Box scores for each scale level,
                    each is a 4D-tensor, the channel number is
                    num_points * num_classes.
                bbox_preds (list[Tensor]): Box energies / deltas for each scale
                    level, each is a 4D-tensor, the channel number is
                    num_points * bbox_code_size.
                dir_cls_preds (list[Tensor]): Box scores for direction class
                    predictions on each scale level, each is a 4D-tensor,
                    the channel number is num_points * 2. (bin = 2).
                attr_preds (list[Tensor]): Attribute scores for each scale
                    level, each is a 4D-tensor, the channel number is
                    num_points * num_attrs.
                centernesses (list[Tensor]): Centerness for each scale level,
                    each is a 4D-tensor, the channel number is num_points * 1.
        """
        per_level_outputs = multi_apply(self.forward_single, x, self.scales,
                                        self.strides)
        # forward_single also returns cls_feat / reg_feat; drop them and keep
        # only the five prediction lists.
        return per_level_outputs[:5]
twang's avatar
twang committed
170

ZCMax's avatar
ZCMax committed
171
172
    def forward_single(self, x: Tensor, scale: Scale,
                       stride: int) -> Tuple[Tensor, ...]:
        """Forward features of a single scale level.

        Args:
            x (Tensor): FPN feature maps of the specified stride.
            scale (:obj: `mmcv.cnn.Scale`): Learnable scale(s) handed to
                ``bbox_coder.decode`` to resize the bbox prediction.
                ``forward`` passes one ``nn.ModuleList`` of
                ``self.scale_dim`` Scale modules per level.
            stride (int): The corresponding stride for feature maps, handed
                to ``bbox_coder.decode``; only used to normalize the bbox
                prediction when self.norm_on_bbox is True.

        Returns:
            tuple: ``(cls_score, bbox_pred, dir_cls_pred, attr_pred,
            centerness, cls_feat, reg_feat)``, where ``bbox_pred`` has been
            decoded by ``self.bbox_coder`` and ``centerness`` is computed
            from ``reg_feat`` if ``self.centerness_on_reg`` else from
            ``cls_feat``.
        """
        (cls_score, bbox_pred, dir_cls_pred, attr_pred, cls_feat,
         reg_feat) = super().forward_single(x)

        # The centerness branch works on a copy of the selected feature.
        source_feat = reg_feat if self.centerness_on_reg else cls_feat
        centerness_feat = source_feat.clone()
        for layer in self.conv_centerness_prev:
            centerness_feat = layer(centerness_feat)
        centerness = self.conv_centerness(centerness_feat)

        decoded_bbox_pred = self.bbox_coder.decode(bbox_pred, scale, stride,
                                                   self.training, cls_score)
        return (cls_score, decoded_bbox_pred, dir_cls_pred, attr_pred,
                centerness, cls_feat, reg_feat)
twang's avatar
twang committed
206
207

    @staticmethod
    def add_sin_difference(boxes1: Tensor,
                           boxes2: Tensor) -> Tuple[Tensor, Tensor]:
        """Convert the rotation difference to difference in sine function.

        The 7th column (index 6) of ``boxes1`` becomes
        ``sin(r1) * cos(r2)`` and that of ``boxes2`` becomes
        ``cos(r1) * sin(r2)``, so their difference equals ``sin(r1 - r2)``.
        All other columns are kept.

        Args:
            boxes1 (torch.Tensor): Original Boxes in shape (NxC), where C>=7
                and the 7th dimension is rotation dimension.
            boxes2 (torch.Tensor): Target boxes in shape (NxC), where C>=7 and
                the 7th dimension is rotation dimension.

        Returns:
            tuple[torch.Tensor]: ``boxes1`` and ``boxes2`` whose 7th
                dimensions are changed.
        """
        rot1 = boxes1[..., 6:7]
        rot2 = boxes2[..., 6:7]
        rot1_encoded = torch.sin(rot1) * torch.cos(rot2)
        rot2_encoded = torch.cos(rot1) * torch.sin(rot2)
        new_boxes1 = torch.cat(
            (boxes1[..., :6], rot1_encoded, boxes1[..., 7:]), dim=-1)
        new_boxes2 = torch.cat(
            (boxes2[..., :6], rot2_encoded, boxes2[..., 7:]), dim=-1)
        return new_boxes1, new_boxes2

    @staticmethod
    def get_direction_target(reg_targets: Tensor,
                             dir_offset: float = 0,
                             dir_limit_offset: float = 0.0,
                             num_bins: int = 2,
                             one_hot: bool = True) -> Tensor:
        """Encode the yaw of the regression targets into direction bins.

        Args:
            reg_targets (torch.Tensor): Bbox regression targets; the yaw is
                read from index 6 of the last dimension.
            dir_offset (float, optional): Angle subtracted from the yaw
                before binning. Defaults to 0.
            dir_limit_offset (float, optional): Offset passed to
                ``limit_period`` to set the direction range. Defaults to 0.0.
            num_bins (int, optional): Number of bins to divide 2*PI.
                Defaults to 2.
            one_hot (bool, optional): Whether to encode as one hot.
                Defaults to True.

        Returns:
            torch.Tensor: Long bin indices in ``[0, num_bins - 1]`` with the
            shape of ``reg_targets[..., 6]``; when ``one_hot`` is True, a
            one-hot tensor with an extra trailing dimension of size
            ``num_bins`` and the dtype of ``reg_targets`` instead.
        """
        rot_gt = reg_targets[..., 6]
        offset_rot = limit_period(rot_gt - dir_offset, dir_limit_offset,
                                  2 * np.pi)
        dir_cls_targets = torch.floor(offset_rot /
                                      (2 * np.pi / num_bins)).long()
        # Guard against floating-point edge values landing outside the bins.
        dir_cls_targets = torch.clamp(dir_cls_targets, min=0, max=num_bins - 1)
        if one_hot:
            dir_targets = torch.zeros(
                *list(dir_cls_targets.shape),
                num_bins,
                dtype=reg_targets.dtype,
                device=dir_cls_targets.device)
            # ``Tensor.scatter_`` takes ``(dim, index, value)``: write 1.0 at
            # each sample's bin index along the last dimension.
            dir_targets.scatter_(-1, dir_cls_targets.unsqueeze(dim=-1), 1.0)
            dir_cls_targets = dir_targets
        return dir_cls_targets

ZCMax's avatar
ZCMax committed
269
270
271
272
273
274
275
276
277
278
279
    def loss_by_feat(
            self,
            cls_scores: List[Tensor],
            bbox_preds: List[Tensor],
            dir_cls_preds: List[Tensor],
            attr_preds: List[Tensor],
            centernesses: List[Tensor],
            batch_gt_instances_3d: InstanceList,
            batch_gt_instacnes: InstanceList,
            batch_img_metas: List[dict],
            batch_gt_instances_ignore: OptInstanceList = None) -> dict:
        """Compute loss of the head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level,
                each is a 4D-tensor, the channel number is
                num_points * num_classes.
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level, each is a 4D-tensor, the channel number is
                num_points * bbox_code_size.
            dir_cls_preds (list[Tensor]): Box scores for direction class
                predictions on each scale level, each is a 4D-tensor,
                the channel number is num_points * 2. (bin = 2). Only read
                when ``self.use_direction_classifier`` is True.
            attr_preds (list[Tensor]): Attribute scores for each scale level,
                each is a 4D-tensor, the channel number is
                num_points * num_attrs. Only read when ``self.pred_attrs``
                is True.
            centernesses (list[Tensor]): Centerness for each scale level, each
                is a 4D-tensor, the channel number is num_points * 1.
            batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
                gt_instance_3d. It usually includes ``bboxes_3d``,
                ``labels_3d``, ``depths``, ``centers_2d`` and attributes.
            batch_gt_instacnes (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes``, ``labels``.
                The misspelled parameter name is kept for compatibility.
            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc. Not used by this method.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
                Batch of gt_instances_ignore. Not used by this method.
                Defaults to None.

        Returns:
            dict[str, Tensor]: Loss components. Always has ``loss_cls``,
            ``loss_offset``, ``loss_depth``, ``loss_size``, ``loss_rotsin``
            and ``loss_centerness``; ``loss_velo``, ``loss_dir`` and
            ``loss_attr`` are added when ``pred_velo``,
            ``use_direction_classifier`` and ``pred_attrs`` are enabled.
        """
        assert len(cls_scores) == len(bbox_preds) == len(centernesses) == len(
            attr_preds)
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        all_level_points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
                                           bbox_preds[0].device)
        labels_3d, bbox_targets_3d, centerness_targets, attr_targets = \
            self.get_targets(all_level_points, batch_gt_instances_3d,
                             batch_gt_instacnes)

        num_imgs = cls_scores[0].size(0)
        # Flatten each level from (N, C, H, W) to (N * H * W, C), then
        # concatenate the levels.
        flatten_cls_scores = torch.cat([
            cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
            for cls_score in cls_scores
        ])
        flatten_bbox_preds = torch.cat([
            bbox_pred.permute(0, 2, 3, 1).reshape(-1, sum(self.group_reg_dims))
            for bbox_pred in bbox_preds
        ])
        flatten_centerness = torch.cat([
            centerness.permute(0, 2, 3, 1).reshape(-1)
            for centerness in centernesses
        ])
        flatten_labels_3d = torch.cat(labels_3d)
        flatten_bbox_targets_3d = torch.cat(bbox_targets_3d)
        flatten_centerness_targets = torch.cat(centerness_targets)

        # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
        bg_class_ind = self.num_classes
        pos_inds = ((flatten_labels_3d >= 0)
                    & (flatten_labels_3d < bg_class_ind)).nonzero().reshape(-1)
        num_pos = len(pos_inds)

        loss_cls = self.loss_cls(
            flatten_cls_scores,
            flatten_labels_3d,
            avg_factor=num_pos + num_imgs)  # avoid num_pos is 0

        pos_bbox_preds = flatten_bbox_preds[pos_inds]
        pos_centerness = flatten_centerness[pos_inds]

        # Direction predictions are only used when the classifier is enabled
        # (predict_by_feat likewise substitutes placeholders otherwise), so
        # do not touch them in that case; they are presumably None then.
        if self.use_direction_classifier:
            flatten_dir_cls_preds = torch.cat([
                dir_cls_pred.permute(0, 2, 3, 1).reshape(-1, 2)
                for dir_cls_pred in dir_cls_preds
            ])
            pos_dir_cls_preds = flatten_dir_cls_preds[pos_inds]

        if self.pred_attrs:
            flatten_attr_preds = torch.cat([
                attr_pred.permute(0, 2, 3, 1).reshape(-1, self.num_attrs)
                for attr_pred in attr_preds
            ])
            flatten_attr_targets = torch.cat(attr_targets)
            pos_attr_preds = flatten_attr_preds[pos_inds]

        if num_pos > 0:
            pos_bbox_targets_3d = flatten_bbox_targets_3d[pos_inds]
            pos_centerness_targets = flatten_centerness_targets[pos_inds]
            if self.pred_attrs:
                pos_attr_targets = flatten_attr_targets[pos_inds]
            bbox_weights = pos_centerness_targets.new_ones(
                len(pos_centerness_targets), sum(self.group_reg_dims))
            equal_weights = pos_centerness_targets.new_ones(
                pos_centerness_targets.shape)
            # Each positive sample counts once when averaging the
            # regression and direction losses.
            num_pos_weight = equal_weights.sum()

            # Tolerate an unset train_cfg: fall back to uniform code weights.
            code_weight = (
                self.train_cfg.get('code_weight', None)
                if self.train_cfg else None)
            if code_weight:
                assert len(code_weight) == sum(self.group_reg_dims)
                bbox_weights = bbox_weights * bbox_weights.new_tensor(
                    code_weight)

            # Must run before add_sin_difference, which overwrites the yaw
            # column (index 6) of the targets.
            if self.use_direction_classifier:
                pos_dir_cls_targets = self.get_direction_target(
                    pos_bbox_targets_3d,
                    self.dir_offset,
                    self.dir_limit_offset,
                    one_hot=False)

            if self.diff_rad_by_sin:
                pos_bbox_preds, pos_bbox_targets_3d = self.add_sin_difference(
                    pos_bbox_preds, pos_bbox_targets_3d)

            loss_offset = self.loss_bbox(
                pos_bbox_preds[:, :2],
                pos_bbox_targets_3d[:, :2],
                weight=bbox_weights[:, :2],
                avg_factor=num_pos_weight)
            loss_depth = self.loss_bbox(
                pos_bbox_preds[:, 2],
                pos_bbox_targets_3d[:, 2],
                weight=bbox_weights[:, 2],
                avg_factor=num_pos_weight)
            loss_size = self.loss_bbox(
                pos_bbox_preds[:, 3:6],
                pos_bbox_targets_3d[:, 3:6],
                weight=bbox_weights[:, 3:6],
                avg_factor=num_pos_weight)
            loss_rotsin = self.loss_bbox(
                pos_bbox_preds[:, 6],
                pos_bbox_targets_3d[:, 6],
                weight=bbox_weights[:, 6],
                avg_factor=num_pos_weight)
            loss_velo = None
            if self.pred_velo:
                loss_velo = self.loss_bbox(
                    pos_bbox_preds[:, 7:9],
                    pos_bbox_targets_3d[:, 7:9],
                    weight=bbox_weights[:, 7:9],
                    avg_factor=num_pos_weight)

            loss_centerness = self.loss_centerness(pos_centerness,
                                                   pos_centerness_targets)

            # direction classification loss
            loss_dir = None
            # TODO: add more check for use_direction_classifier
            if self.use_direction_classifier:
                loss_dir = self.loss_dir(
                    pos_dir_cls_preds,
                    pos_dir_cls_targets,
                    equal_weights,
                    avg_factor=num_pos_weight)

            # attribute classification loss, weighted by centerness targets
            loss_attr = None
            if self.pred_attrs:
                loss_attr = self.loss_attr(
                    pos_attr_preds,
                    pos_attr_targets,
                    pos_centerness_targets,
                    avg_factor=pos_centerness_targets.sum())

        else:
            # No positive samples: summing the empty selections yields
            # zero-valued losses that are still computed from the
            # predictions, and the dict keeps the same keys as above.
            loss_offset = pos_bbox_preds[:, :2].sum()
            loss_depth = pos_bbox_preds[:, 2].sum()
            loss_size = pos_bbox_preds[:, 3:6].sum()
            loss_rotsin = pos_bbox_preds[:, 6].sum()
            loss_velo = None
            if self.pred_velo:
                loss_velo = pos_bbox_preds[:, 7:9].sum()
            loss_centerness = pos_centerness.sum()
            loss_dir = None
            if self.use_direction_classifier:
                loss_dir = pos_dir_cls_preds.sum()
            loss_attr = None
            if self.pred_attrs:
                loss_attr = pos_attr_preds.sum()

        loss_dict = dict(
            loss_cls=loss_cls,
            loss_offset=loss_offset,
            loss_depth=loss_depth,
            loss_size=loss_size,
            loss_rotsin=loss_rotsin,
            loss_centerness=loss_centerness)

        if loss_velo is not None:
            loss_dict['loss_velo'] = loss_velo

        if loss_dir is not None:
            loss_dict['loss_dir'] = loss_dir

        if loss_attr is not None:
            loss_dict['loss_attr'] = loss_attr

        return loss_dict

ZCMax's avatar
ZCMax committed
484
485
486
487
488
489
490
491
492
    def predict_by_feat(self,
                        cls_scores: List[Tensor],
                        bbox_preds: List[Tensor],
                        dir_cls_preds: List[Tensor],
                        attr_preds: List[Tensor],
                        centernesses: List[Tensor],
                        batch_img_metas: Optional[List[dict]] = None,
                        cfg: OptConfigType = None,
                        rescale: bool = False
                        ) -> Tuple[InstanceList, Optional[InstanceList]]:
        """Transform network output for a batch into bbox predictions.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level
                Has shape (N, num_points * num_classes, H, W)
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level with shape (N, num_points * bbox_code_size, H, W)
            dir_cls_preds (list[Tensor]): Box scores for direction class
                predictions on each scale level, each is a 4D-tensor,
                the channel number is num_points * 2. (bin = 2). Ignored
                (replaced by zeros) when the direction classifier is off.
            attr_preds (list[Tensor]): Attribute scores for each scale level
                Has shape (N, num_points * num_attrs, H, W). Ignored
                (replaced by ``attr_background_label``) when
                ``pred_attrs`` is False.
            centernesses (list[Tensor]): Centerness for each scale level with
                shape (N, num_points * 1, H, W)
            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc. Required despite the
                ``None`` default: it is iterated once per image.
            cfg (ConfigDict, optional): Test / postprocessing
                configuration, if None, test_cfg would be used.
                Defaults to None.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.

        Returns:
            tuple: ``(result_list, result_list_2d)`` where

                - result_list (list[:obj:`InstanceData`]): 3D detection
                  results of each image after the post process, as built by
                  ``_predict_by_feat_single``. Each item usually contains
                  ``scores_3d`` (num_instances, ), ``labels_3d``
                  (num_instances, ) and ``bboxes_3d`` with a tensor of shape
                  (num_instances, C), where C >= 7.
                - result_list_2d (None): Always None; no 2D results are
                  produced by this head.
        """
        assert len(cls_scores) == len(bbox_preds) == len(dir_cls_preds) == \
            len(centernesses) == len(attr_preds)
        num_levels = len(cls_scores)

        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        # TODO: refactor using prior_generator
        mlvl_points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
                                      bbox_preds[0].device)
        result_list = []
        for img_id, img_meta in enumerate(batch_img_metas):
            cls_score_list = select_single_mlvl(cls_scores, img_id)
            bbox_pred_list = select_single_mlvl(bbox_preds, img_id)

            if self.use_direction_classifier:
                dir_cls_pred_list = select_single_mlvl(dir_cls_preds, img_id)
            else:
                # Placeholder: two all-zero direction channels per level.
                dir_cls_pred_list = [
                    cls_scores[i][img_id].new_full(
                        [2, *cls_scores[i][img_id].shape[1:]], 0).detach()
                    for i in range(num_levels)
                ]

            if self.pred_attrs:
                attr_pred_list = select_single_mlvl(attr_preds, img_id)
            else:
                # Placeholder: every attribute channel set to the background
                # label.
                attr_pred_list = [
                    cls_scores[i][img_id].new_full(
                        [self.num_attrs, *cls_scores[i][img_id].shape[1:]],
                        self.attr_background_label).detach()
                    for i in range(num_levels)
                ]

            centerness_pred_list = select_single_mlvl(centernesses, img_id)
            results = self._predict_by_feat_single(
                cls_score_list=cls_score_list,
                bbox_pred_list=bbox_pred_list,
                dir_cls_pred_list=dir_cls_pred_list,
                attr_pred_list=attr_pred_list,
                centerness_pred_list=centerness_pred_list,
                mlvl_points=mlvl_points,
                img_meta=img_meta,
                cfg=cfg,
                rescale=rescale)
            result_list.append(results)

        result_list_2d = None
        return result_list, result_list_2d
twang's avatar
twang committed
573

ZCMax's avatar
ZCMax committed
574
575
576
577
578
579
580
581
582
583
    def _predict_by_feat_single(self,
                                cls_score_list: List[Tensor],
                                bbox_pred_list: List[Tensor],
                                dir_cls_pred_list: List[Tensor],
                                attr_pred_list: List[Tensor],
                                centerness_pred_list: List[Tensor],
                                mlvl_points: List[Tensor],
                                img_meta: dict,
                                cfg: ConfigType,
                                rescale: bool = False) -> InstanceData:
        """Transform outputs for a single batch item into bbox predictions.

        Args:
            cls_score_list (list[Tensor]): Box scores of each scale level,
                each of shape (cls_out_channels, H, W).
            bbox_pred_list (list[Tensor]): Box energies / deltas of each
                scale level, each of shape (sum(group_reg_dims), H, W). Only
                the first ``bbox_code_size`` channels are decoded.
            dir_cls_pred_list (list[Tensor]): Direction class predictions of
                each scale level, each of shape (2, H, W).
            attr_pred_list (list[Tensor]): Attribute scores of each scale
                level, each of shape (num_attrs, H, W).
            centerness_pred_list (list[Tensor]): Centerness logits of each
                scale level; presumably of shape (1, H, W) so that they
                align one-to-one with the points (TODO confirm).
            mlvl_points (list[Tensor]): Points of each scale level, each of
                shape (num_points_lvl, 2) in (x, y) order.
            img_meta (dict): Metadata of the input image. Must provide
                ``cam2img``, ``scale_factor`` and ``box_type_3d``.
            cfg (mmengine.Config): Test / postprocessing configuration,
                if None, test_cfg would be used.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.

        Returns:
            :obj:`InstanceData`: 3D detection results of the image after
            post-processing. It contains the following keys.

                - scores_3d (Tensor): Classification scores, has a shape
                  (num_instances, ).
                - labels_3d (Tensor): Labels of bboxes, has a shape
                  (num_instances, ).
                - bboxes_3d: Boxes built with ``img_meta['box_type_3d']``,
                  ``bbox_code_size`` values per box.
                - attr_labels (Tensor, optional): Attribute labels, only set
                  when ``self.pred_attrs`` is True.
        """
        view = np.array(img_meta['cam2img'])
        scale_factor = img_meta['scale_factor']
        cfg = self.test_cfg if cfg is None else cfg
        assert len(cls_score_list) == len(bbox_pred_list) == len(mlvl_points)
        # the pre-NMS top-k limit is the same for every level
        nms_pre = cfg.get('nms_pre', -1)

        mlvl_centers_2d = []
        mlvl_bboxes = []
        mlvl_scores = []
        mlvl_dir_scores = []
        mlvl_attr_scores = []
        mlvl_centerness = []

        for cls_score, bbox_pred, dir_cls_pred, attr_pred, centerness, \
                points in zip(cls_score_list, bbox_pred_list,
                              dir_cls_pred_list, attr_pred_list,
                              centerness_pred_list, mlvl_points):
            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
            # flatten every (C, H, W) map into (H * W, C) rows, one per point
            scores = cls_score.permute(1, 2, 0).reshape(
                -1, self.cls_out_channels).sigmoid()
            dir_cls_pred = dir_cls_pred.permute(1, 2, 0).reshape(-1, 2)
            dir_cls_score = torch.max(dir_cls_pred, dim=-1)[1]
            attr_pred = attr_pred.permute(1, 2, 0).reshape(-1, self.num_attrs)
            attr_score = torch.max(attr_pred, dim=-1)[1]
            centerness = centerness.permute(1, 2, 0).reshape(-1).sigmoid()

            bbox_pred = bbox_pred.permute(1, 2, 0).reshape(
                -1, sum(self.group_reg_dims))
            bbox_pred = bbox_pred[:, :self.bbox_code_size]
            if nms_pre > 0 and scores.shape[0] > nms_pre:
                # keep the nms_pre points with the best
                # centerness-weighted class score
                max_scores, _ = (scores * centerness[:, None]).max(dim=1)
                _, topk_inds = max_scores.topk(nms_pre)
                points = points[topk_inds, :]
                bbox_pred = bbox_pred[topk_inds, :]
                scores = scores[topk_inds, :]
                centerness = centerness[topk_inds]
                dir_cls_score = dir_cls_score[topk_inds]
                attr_score = attr_score[topk_inds]
            # change the offset to actual center predictions
            bbox_pred[:, :2] = points - bbox_pred[:, :2]
            if rescale:
                bbox_pred[:, :2] /= bbox_pred[:, :2].new_tensor(scale_factor)
            # keep the image-space (u, v, depth) centers for yaw decoding
            pred_center2d = bbox_pred[:, :3].clone()
            bbox_pred[:, :3] = points_img2cam(bbox_pred[:, :3], view)
            mlvl_centers_2d.append(pred_center2d)
            mlvl_bboxes.append(bbox_pred)
            mlvl_scores.append(scores)
            mlvl_dir_scores.append(dir_cls_score)
            mlvl_attr_scores.append(attr_score)
            mlvl_centerness.append(centerness)

        mlvl_centers_2d = torch.cat(mlvl_centers_2d)
        mlvl_bboxes = torch.cat(mlvl_bboxes)
        mlvl_dir_scores = torch.cat(mlvl_dir_scores)

        # change local yaw to global yaw for 3D nms; the camera matrix
        # (which may be smaller than 4x4) is zero-padded into a 4x4 matrix
        cam2img = mlvl_centers_2d.new_zeros((4, 4))
        cam2img[:view.shape[0], :view.shape[1]] = \
            mlvl_centers_2d.new_tensor(view)
        mlvl_bboxes = self.bbox_coder.decode_yaw(mlvl_bboxes, mlvl_centers_2d,
                                                 mlvl_dir_scores,
                                                 self.dir_offset, cam2img)

        mlvl_bboxes_for_nms = xywhr2xyxyr(img_meta['box_type_3d'](
            mlvl_bboxes, box_dim=self.bbox_code_size,
            origin=(0.5, 0.5, 0.5)).bev)

        mlvl_scores = torch.cat(mlvl_scores)
        padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
        # remind that we set FG labels to [0, num_class-1] since mmdet v2.0
        # BG cat_id: num_class
        mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)
        mlvl_attr_scores = torch.cat(mlvl_attr_scores)
        mlvl_centerness = torch.cat(mlvl_centerness)
        # box3d_multiclass_nms takes no separate scale factors, so the
        # centerness weighting is folded into the scores here
        mlvl_nms_scores = mlvl_scores * mlvl_centerness[:, None]
        nms_results = box3d_multiclass_nms(mlvl_bboxes, mlvl_bboxes_for_nms,
                                           mlvl_nms_scores, cfg.score_thr,
                                           cfg.max_per_img, cfg,
                                           mlvl_dir_scores, mlvl_attr_scores)
        bboxes, scores, labels, dir_scores, attrs = nms_results
        # Note that the predictions use origin (0.5, 0.5, 0.5)
        # Due to the ground truth centers_2d are the gravity center of objects
        # v0.10.0 fix inplace operation to the input tensor of cam_box3d
        # So here we also need to add origin=(0.5, 0.5, 0.5)
        bboxes = img_meta['box_type_3d'](
            bboxes, box_dim=self.bbox_code_size, origin=(0.5, 0.5, 0.5))

        results = InstanceData()
        results.bboxes_3d = bboxes
        results.scores_3d = scores
        results.labels_3d = labels
        if self.pred_attrs and attrs is not None:
            # convert only after the None check; attrs may be missing
            results.attr_labels = attrs.to(labels.dtype)

        return results
twang's avatar
twang committed
716

ZCMax's avatar
ZCMax committed
717
718
719
720
721
722
723
724
    def _get_points_single(self,
                           featmap_size: Tuple[int],
                           stride: int,
                           dtype: torch.dtype,
                           device: torch.device,
                           flatten: bool = False) -> Tensor:
        """Compute the image-space points of one feature level.

        The parent implementation returns ``(y, x)`` coordinate grids of the
        feature map. Each coordinate is scaled by ``stride`` and then offset
        by ``stride // 2``.

        Args:
            featmap_size (tuple[int]): Spatial size of the feature map.
            stride (int): Downsample factor of this level.
            dtype (torch.dtype): Dtype of the returned points.
            device (torch.device): Device of the returned points.
            flatten (bool): Accepted for interface compatibility. It is not
                forwarded to the parent here, and the result is always
                flattened. Defaults to False.

        Returns:
            Tensor: Points of shape (num_points, 2) in (x, y) order.
        """
        grid_y, grid_x = super()._get_points_single(featmap_size, stride,
                                                    dtype, device)
        half_stride = stride // 2
        coords = torch.stack(
            (grid_x.reshape(-1) * stride, grid_y.reshape(-1) * stride),
            dim=-1)
        return coords + half_stride

ZCMax's avatar
ZCMax committed
741
742
743
744
745
746
    def get_targets(
        self,
        points: List[Tensor],
        batch_gt_instances_3d: InstanceList,
        batch_gt_instances: InstanceList,
    ) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]:
        """Compute regression, classification and centerness targets for
        points in multiple images.

        If the first image has no ``attr_labels``, every image in the batch
        is given background attribute labels in place.

        Args:
            points (list[Tensor]): Points of each fpn level, each has shape
                (num_points, 2).
            batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
                gt_instance_3d. It usually includes ``bboxes_3d``,
                ``labels_3d``, ``depths``, ``centers_2d`` and attributes.
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes``, ``labels``.

        Returns:
            tuple: Targets of each level.

            - concat_lvl_labels_3d (list[Tensor]): 3D Labels of each level.
            - concat_lvl_bbox_targets_3d (list[Tensor]): 3D BBox targets of
                each level.
            - concat_lvl_centerness_targets (list[Tensor]): Centerness targets
                of each level.
            - concat_lvl_attr_targets (list[Tensor]): Attribute targets of
                each level.
        """
        assert len(points) == len(self.regress_ranges)
        num_levels = len(points)
        # number of points on each level (shared by every image)
        num_points = [lvl_points.size(0) for lvl_points in points]

        # give every point the regress range of its level, then flatten
        # points and ranges across all levels
        concat_regress_ranges = torch.cat([
            lvl_points.new_tensor(lvl_range)[None].expand_as(lvl_points)
            for lvl_points, lvl_range in zip(points, self.regress_ranges)
        ],
                                          dim=0)
        concat_points = torch.cat(points, dim=0)

        # datasets without attribute annotations fall back to background
        if 'attr_labels' not in batch_gt_instances_3d[0]:
            for gt_instances_3d in batch_gt_instances_3d:
                labels_3d = gt_instances_3d.labels_3d
                gt_instances_3d.attr_labels = labels_3d.new_full(
                    labels_3d.shape, self.attr_background_label)

        target_tuple = multi_apply(
            self._get_target_single,
            batch_gt_instances_3d,
            batch_gt_instances,
            points=concat_points,
            regress_ranges=concat_regress_ranges,
            num_points_per_lvl=num_points)
        # the first two outputs (2D labels / 2D bbox targets) are unused
        labels_3d_list, bbox_targets_3d_list, centerness_targets_list, \
            attr_targets_list = target_tuple[2:]

        def _regroup_by_level(per_img_targets):
            # split each image's targets by level, then concatenate the
            # same level across all images
            per_img_lvls = [t.split(num_points, 0) for t in per_img_targets]
            return [
                torch.cat([lvls[lvl] for lvls in per_img_lvls])
                for lvl in range(num_levels)
            ]

        concat_lvl_labels_3d = _regroup_by_level(labels_3d_list)
        concat_lvl_bbox_targets_3d = _regroup_by_level(bbox_targets_3d_list)
        concat_lvl_centerness_targets = _regroup_by_level(
            centerness_targets_list)
        concat_lvl_attr_targets = _regroup_by_level(attr_targets_list)

        if self.norm_on_bbox:
            # express the 2D center offsets in units of the level stride
            for lvl, tgt in enumerate(concat_lvl_bbox_targets_3d):
                tgt[:, :2] = tgt[:, :2] / self.strides[lvl]

        return (concat_lvl_labels_3d, concat_lvl_bbox_targets_3d,
                concat_lvl_centerness_targets, concat_lvl_attr_targets)

ZCMax's avatar
ZCMax committed
845
846
847
848
    def _get_target_single(
            self, gt_instances_3d: InstanceData, gt_instances: InstanceData,
            points: Tensor, regress_ranges: Tensor,
            num_points_per_lvl: List[int]) -> Tuple[Tensor, ...]:
        """Compute regression and classification targets for a single image.

        Args:
            gt_instances_3d (:obj:`InstanceData`): 3D ground truth of the
                image. Reads ``bboxes_3d``, ``labels_3d``, ``centers_2d``,
                ``depths`` and ``attr_labels``.
            gt_instances (:obj:`InstanceData`): 2D ground truth of the image.
                Reads ``bboxes`` in (x1, y1, x2, y2) format and ``labels``.
            points (Tensor): Points of all levels concatenated, shape
                (num_points, 2) in (x, y) order.
            regress_ranges (Tensor): Regress range of every point, shape
                (num_points, 2).
            num_points_per_lvl (list[int]): Number of points on each level,
                used to assign each point the stride of its level.

        Returns:
            tuple[Tensor]: ``(labels, bbox_targets, labels_3d,
            bbox_targets_3d, centerness_targets, attr_labels)``, each with
            ``num_points`` rows. Points not assigned to any ground truth get
            background labels.
        """
        num_points = points.size(0)
        num_gts = len(gt_instances_3d)
        gt_bboxes = gt_instances.bboxes
        gt_labels = gt_instances.labels
        gt_bboxes_3d = gt_instances_3d.bboxes_3d
        gt_labels_3d = gt_instances_3d.labels_3d
        centers_2d = gt_instances_3d.centers_2d
        depths = gt_instances_3d.depths
        attr_labels = gt_instances_3d.attr_labels

        if not isinstance(gt_bboxes_3d, torch.Tensor):
            gt_bboxes_3d = gt_bboxes_3d.tensor.to(gt_bboxes.device)
        if num_gts == 0:
            return gt_labels.new_full((num_points,), self.background_label), \
                   gt_bboxes.new_zeros((num_points, 4)), \
                   gt_labels_3d.new_full(
                       (num_points,), self.background_label), \
                   gt_bboxes_3d.new_zeros((num_points, self.bbox_code_size)), \
                   gt_bboxes_3d.new_zeros((num_points,)), \
                   attr_labels.new_full(
                       (num_points,), self.attr_background_label)

        # Change orientation to local yaw. Work on a copy: ``.to()`` returns
        # the same tensor when no device move is needed (and a Tensor input
        # is used as-is), so an in-place write would silently modify the
        # caller's ground-truth boxes.
        gt_bboxes_3d = gt_bboxes_3d.clone()
        gt_bboxes_3d[..., 6] = -torch.atan2(
            gt_bboxes_3d[..., 0], gt_bboxes_3d[..., 2]) + gt_bboxes_3d[..., 6]

        # broadcast everything to (num_points, num_gts, ...)
        regress_ranges = regress_ranges[:, None, :].expand(
            num_points, num_gts, 2)
        gt_bboxes = gt_bboxes[None].expand(num_points, num_gts, 4)
        centers_2d = centers_2d[None].expand(num_points, num_gts, 2)
        gt_bboxes_3d = gt_bboxes_3d[None].expand(num_points, num_gts,
                                                 self.bbox_code_size)
        depths = depths[None, :, None].expand(num_points, num_gts, 1)
        xs, ys = points[:, 0], points[:, 1]
        xs = xs[:, None].expand(num_points, num_gts)
        ys = ys[:, None].expand(num_points, num_gts)

        # 3D targets: offset from the point to the projected 2D center,
        # depth, then the remaining 3D box parameters
        delta_xs = (xs - centers_2d[..., 0])[..., None]
        delta_ys = (ys - centers_2d[..., 1])[..., None]
        bbox_targets_3d = torch.cat(
            (delta_xs, delta_ys, depths, gt_bboxes_3d[..., 3:]), dim=-1)

        # 2D targets: distances from the point to the four box sides
        left = xs - gt_bboxes[..., 0]
        right = gt_bboxes[..., 2] - xs
        top = ys - gt_bboxes[..., 1]
        bottom = gt_bboxes[..., 3] - ys
        bbox_targets = torch.stack((left, top, right, bottom), -1)

        assert self.center_sampling is True, 'Setting center_sampling to '\
            'False has not been implemented for FCOS3D.'
        # condition1: inside a `center bbox`
        radius = self.center_sample_radius
        center_xs = centers_2d[..., 0]
        center_ys = centers_2d[..., 1]
        center_gts = torch.zeros_like(gt_bboxes)
        stride = center_xs.new_zeros(center_xs.shape)

        # project the points on current lvl back to the `original` sizes
        lvl_begin = 0
        for lvl_idx, num_points_lvl in enumerate(num_points_per_lvl):
            lvl_end = lvl_begin + num_points_lvl
            stride[lvl_begin:lvl_end] = self.strides[lvl_idx] * radius
            lvl_begin = lvl_end

        center_gts[..., 0] = center_xs - stride
        center_gts[..., 1] = center_ys - stride
        center_gts[..., 2] = center_xs + stride
        center_gts[..., 3] = center_ys + stride

        cb_dist_left = xs - center_gts[..., 0]
        cb_dist_right = center_gts[..., 2] - xs
        cb_dist_top = ys - center_gts[..., 1]
        cb_dist_bottom = center_gts[..., 3] - ys
        center_bbox = torch.stack(
            (cb_dist_left, cb_dist_top, cb_dist_right, cb_dist_bottom), -1)
        inside_gt_bbox_mask = center_bbox.min(-1)[0] > 0

        # condition2: limit the regression range for each location
        max_regress_distance = bbox_targets.max(-1)[0]
        inside_regress_range = (
            (max_regress_distance >= regress_ranges[..., 0])
            & (max_regress_distance <= regress_ranges[..., 1]))

        # center-based criterion to deal with ambiguity: each point takes
        # the gt whose 2D center is nearest among the valid candidates
        dists = torch.sqrt(torch.sum(bbox_targets_3d[..., :2]**2, dim=-1))
        dists[inside_gt_bbox_mask == 0] = INF
        dists[inside_regress_range == 0] = INF
        min_dist, min_dist_inds = dists.min(dim=1)

        labels = gt_labels[min_dist_inds]
        labels_3d = gt_labels_3d[min_dist_inds]
        attr_labels = attr_labels[min_dist_inds]
        labels[min_dist == INF] = self.background_label  # set as BG
        labels_3d[min_dist == INF] = self.background_label  # set as BG
        attr_labels[min_dist == INF] = self.attr_background_label

        bbox_targets = bbox_targets[range(num_points), min_dist_inds]
        bbox_targets_3d = bbox_targets_3d[range(num_points), min_dist_inds]
        # distance to the center normalized by ~sqrt(2) * (stride * radius)
        relative_dists = torch.sqrt(
            torch.sum(bbox_targets_3d[..., :2]**2,
                      dim=-1)) / (1.414 * stride[:, 0])
        # [N, 1] / [N, 1]
        centerness_targets = torch.exp(-self.centerness_alpha * relative_dists)

        return labels, bbox_targets, labels_3d, bbox_targets_3d, \
            centerness_targets, attr_labels