# ssd_3d_head.py
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union

import torch
from mmcv.ops.nms import batched_nms
from mmdet.models.utils import multi_apply
from mmengine import ConfigDict
from mmengine.structures import InstanceData
from torch import Tensor
from torch.nn import functional as F

from mmdet3d.registry import MODELS
from mmdet3d.structures import BaseInstance3DBoxes
from mmdet3d.structures.bbox_3d import (DepthInstance3DBoxes,
                                        LiDARInstance3DBoxes,
                                        rotation_3d_in_axis)
from .vote_head import VoteHead


@MODELS.register_module()
class SSD3DHead(VoteHead):
    r"""Bbox head of `3DSSD <https://arxiv.org/abs/2002.10187>`_.

    Args:
        num_classes (int): The number of classes.
        bbox_coder (:obj:`BaseBBoxCoder`): Bbox coder for encoding and
            decoding boxes.
        train_cfg (dict): Config for training.
        test_cfg (dict): Config for testing.
        vote_module_cfg (dict): Config of VoteModule for point-wise votes.
        vote_aggregation_cfg (dict): Config of vote aggregation layer.
        pred_layer_cfg (dict): Config of classification and regression
            prediction layers.
        conv_cfg (dict): Config of convolution in prediction layer.
        norm_cfg (dict): Config of BN in prediction layer.
        act_cfg (dict): Config of activation in prediction layer.
        objectness_loss (dict): Config of objectness loss.
        center_loss (dict): Config of center loss.
        dir_class_loss (dict): Config of direction classification loss.
        dir_res_loss (dict): Config of direction residual regression loss.
        size_res_loss (dict): Config of size residual regression loss.
        corner_loss (dict): Config of bbox corners regression loss.
        vote_loss (dict): Config of candidate points regression loss.
    """

    def __init__(self,
                 num_classes: int,
                 bbox_coder: Union[ConfigDict, dict],
                 train_cfg: Optional[dict] = None,
                 test_cfg: Optional[dict] = None,
                 vote_module_cfg: Optional[dict] = None,
                 vote_aggregation_cfg: Optional[dict] = None,
                 pred_layer_cfg: Optional[dict] = None,
                 objectness_loss: Optional[dict] = None,
                 center_loss: Optional[dict] = None,
                 dir_class_loss: Optional[dict] = None,
                 dir_res_loss: Optional[dict] = None,
                 size_res_loss: Optional[dict] = None,
                 corner_loss: Optional[dict] = None,
                 vote_loss: Optional[dict] = None,
                 init_cfg: Optional[dict] = None) -> None:
        """Initialize the head on top of :class:`VoteHead`.

        Most arguments are forwarded unchanged to the parent constructor;
        see the class docstring for their meaning.

        Note:
            Although ``vote_module_cfg``, ``corner_loss`` and ``vote_loss``
            default to ``None`` in the signature, this constructor reads
            ``vote_module_cfg['num_points']`` and passes ``corner_loss`` and
            ``vote_loss`` to ``MODELS.build``, so they are expected to be
            provided.
        """
        # ``size_class_loss`` and ``semantic_loss`` are explicitly passed as
        # None to the parent: this head does not configure them.
        super().__init__(
            num_classes,
            bbox_coder,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            vote_module_cfg=vote_module_cfg,
            vote_aggregation_cfg=vote_aggregation_cfg,
            pred_layer_cfg=pred_layer_cfg,
            objectness_loss=objectness_loss,
            center_loss=center_loss,
            dir_class_loss=dir_class_loss,
            dir_res_loss=dir_res_loss,
            size_class_loss=None,
            size_res_loss=size_res_loss,
            semantic_loss=None,
            init_cfg=init_cfg)

        # Losses specific to this head, built from their configs.
        self.corner_loss = MODELS.build(corner_loss)
        self.vote_loss = MODELS.build(vote_loss)

        # Number of candidate points; used to slice seed points and to size
        # the all-zero targets generated for scenes without valid GT.
        self.num_candidates = vote_module_cfg['num_points']

    def _get_cls_out_channels(self) -> int:
83
84
85
86
        """Return the channel number of classification outputs."""
        # Class numbers (k) + objectness (1)
        return self.num_classes

87
    def _get_reg_out_channels(self) -> int:
88
89
90
91
92
93
        """Return the channel number of regression outputs."""
        # Bbox classification and regression
        # (center residual (3), size regression (3)
        # heading class+residual (num_dir_bins*2)),
        return 3 + 3 + self.num_dir_bins * 2

94
    def _extract_input(self, feat_dict: dict) -> Tuple:
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
        """Extract inputs from features dictionary.

        Args:
            feat_dict (dict): Feature dict from backbone.

        Returns:
            torch.Tensor: Coordinates of input points.
            torch.Tensor: Features of input points.
            torch.Tensor: Indices of input points.
        """
        seed_points = feat_dict['sa_xyz'][-1]
        seed_features = feat_dict['sa_features'][-1]
        seed_indices = feat_dict['sa_indices'][-1]

        return seed_points, seed_features, seed_indices

111
112
113
114
115
116
117
118
119
120
    def loss_by_feat(
            self,
            points: List[torch.Tensor],
            bbox_preds_dict: dict,
            batch_gt_instances_3d: List[InstanceData],
            batch_pts_semantic_mask: Optional[List[torch.Tensor]] = None,
            batch_pts_instance_mask: Optional[List[torch.Tensor]] = None,
            batch_input_metas: List[dict] = None,
            ret_target: bool = False,
            **kwargs) -> dict:
        """Compute loss.

        Args:
            points (list[torch.Tensor]): Input points.
            bbox_preds_dict (dict): Predictions from forward of vote head.
                The keys ``obj_scores``, ``center_offset``, ``center``,
                ``dir_class``, ``dir_res_norm``, ``dir_res``, ``size`` and
                ``vote_offset`` are read here.
            batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
                gt_instances. It usually includes ``bboxes_3d`` and
                ``labels_3d`` attributes.
            batch_pts_semantic_mask (list[tensor]): Semantic mask
                of points cloud. Defaults to None.
            batch_pts_instance_mask (list[tensor]): Instance mask
                of points cloud. Defaults to None.
            batch_input_metas (list[dict]): Contain pcd and img's meta info.
                Although it defaults to None in the signature, the
                ``box_type_3d`` entry of the first item is required to
                compute the corner loss.
            ret_target (bool): Return targets or not. Currently unused by
                this implementation. Defaults to False.

        Returns:
            dict: Losses of 3DSSD.
        """
        targets = self.get_targets(points, bbox_preds_dict,
                                   batch_gt_instances_3d,
                                   batch_pts_semantic_mask,
                                   batch_pts_instance_mask)
        # Underscore-prefixed entries are produced by ``get_targets`` but
        # are not needed for any of the loss terms below.
        (vote_targets, center_targets, size_res_targets, dir_class_targets,
         dir_res_targets, _mask_targets, centerness_targets, corner3d_targets,
         vote_mask, _positive_mask, _negative_mask, centerness_weights,
         box_loss_weights, heading_res_loss_weight) = targets

        # calculate centerness loss
        centerness_loss = self.loss_objectness(
            bbox_preds_dict['obj_scores'].transpose(2, 1),
            centerness_targets,
            weight=centerness_weights)

        # calculate center loss
        center_loss = self.loss_center(
            bbox_preds_dict['center_offset'],
            center_targets,
            weight=box_loss_weights.unsqueeze(-1))

        # calculate direction class loss
        dir_class_loss = self.loss_dir_class(
            bbox_preds_dict['dir_class'].transpose(1, 2),
            dir_class_targets,
            weight=box_loss_weights)

        # calculate direction residual loss; the scalar residual target is
        # repeated over all direction bins and the weight (one-hot over the
        # target bin) selects the bin that actually contributes.
        dir_res_loss = self.loss_dir_res(
            bbox_preds_dict['dir_res_norm'],
            dir_res_targets.unsqueeze(-1).repeat(1, 1, self.num_dir_bins),
            weight=heading_res_loss_weight)

        # calculate size residual loss
        size_loss = self.loss_size_res(
            bbox_preds_dict['size'],
            size_res_targets,
            weight=box_loss_weights.unsqueeze(-1))

        # calculate corner loss: boxes are decoded with the ground-truth
        # direction class (one-hot) instead of the predicted class scores.
        one_hot_dir_class_targets = dir_class_targets.new_zeros(
            bbox_preds_dict['dir_class'].shape)
        one_hot_dir_class_targets.scatter_(2, dir_class_targets.unsqueeze(-1),
                                           1)
        pred_bbox3d = self.bbox_coder.decode(
            dict(
                center=bbox_preds_dict['center'],
                dir_res=bbox_preds_dict['dir_res'],
                dir_class=one_hot_dir_class_targets,
                size=bbox_preds_dict['size']))
        pred_bbox3d = pred_bbox3d.reshape(-1, pred_bbox3d.shape[-1])
        # The box type of the first sample is used for the whole batch.
        pred_bbox3d = batch_input_metas[0]['box_type_3d'](
            pred_bbox3d.clone(),
            box_dim=pred_bbox3d.shape[-1],
            with_yaw=self.bbox_coder.with_rot,
            origin=(0.5, 0.5, 0.5))
        pred_corners3d = pred_bbox3d.corners.reshape(-1, 8, 3)
        corner_loss = self.corner_loss(
            pred_corners3d,
            corner3d_targets.reshape(-1, 8, 3),
            weight=box_loss_weights.view(-1, 1, 1))

        # calculate vote loss
        vote_loss = self.vote_loss(
            bbox_preds_dict['vote_offset'].transpose(1, 2),
            vote_targets,
            weight=vote_mask.unsqueeze(-1))

        losses = dict(
            centerness_loss=centerness_loss,
            center_loss=center_loss,
            dir_class_loss=dir_class_loss,
            dir_res_loss=dir_res_loss,
            size_res_loss=size_loss,
            corner_loss=corner_loss,
            vote_loss=vote_loss)

        return losses

    def get_targets(
        self,
        points: List[Tensor],
        bbox_preds_dict: dict = None,
        batch_gt_instances_3d: List[InstanceData] = None,
        batch_pts_semantic_mask: List[torch.Tensor] = None,
        batch_pts_instance_mask: List[torch.Tensor] = None,
    ) -> Tuple[Tensor, ...]:
        """Generate targets of 3DSSD head.

        Args:
            points (list[torch.Tensor]): Points of each batch.
            bbox_preds_dict (dict): Bounding box predictions of
                vote head. The ``aggregated_points`` and ``seed_points``
                entries are read here, so it must be provided even though
                the signature defaults to None.
            batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
                gt_instances. It usually includes ``bboxes_3d`` and
                ``labels_3d`` attributes. Must be provided even though the
                signature defaults to None.
            batch_pts_semantic_mask (list[tensor]): Semantic gt mask for
                point clouds. Defaults to None.
            batch_pts_instance_mask (list[tensor]): Instance gt mask for
                point clouds. Defaults to None.

        Returns:
            tuple[torch.Tensor]: Targets of 3DSSD head, in this order:
            vote_targets, center_targets, size_res_targets,
            dir_class_targets, dir_res_targets, mask_targets,
            centerness_targets, corner3d_targets, vote_mask, positive_mask,
            negative_mask, centerness_weights, box_loss_weights,
            heading_res_loss_weight.
        """
        batch_gt_labels_3d = [
            gt_instances_3d.labels_3d
            for gt_instances_3d in batch_gt_instances_3d
        ]
        batch_gt_bboxes_3d = [
            gt_instances_3d.bboxes_3d
            for gt_instances_3d in batch_gt_instances_3d
        ]
        batch_size = len(batch_gt_labels_3d)

        # Samples without any GT get a single all-zero box with label 0 so
        # that the per-sample target computation always has a box to index.
        for index, gt_labels in enumerate(batch_gt_labels_3d):
            if len(gt_labels) == 0:
                gt_bboxes = batch_gt_bboxes_3d[index]
                fake_box = gt_bboxes.tensor.new_zeros(
                    1, gt_bboxes.tensor.shape[-1])
                batch_gt_bboxes_3d[index] = gt_bboxes.new_box(fake_box)
                batch_gt_labels_3d[index] = gt_labels.new_zeros(1)

        # Default each mask list independently: ``multi_apply`` iterates
        # over every argument, so a bare None for either one would fail.
        if batch_pts_semantic_mask is None:
            batch_pts_semantic_mask = [None] * batch_size
        if batch_pts_instance_mask is None:
            batch_pts_instance_mask = [None] * batch_size

        aggregated_points = [
            bbox_preds_dict['aggregated_points'][i]
            for i in range(batch_size)
        ]

        seed_points = [
            bbox_preds_dict['seed_points'][i, :self.num_candidates].detach()
            for i in range(batch_size)
        ]

        (vote_targets, center_targets, size_res_targets, dir_class_targets,
         dir_res_targets, mask_targets, centerness_targets, corner3d_targets,
         vote_mask, positive_mask, negative_mask) = multi_apply(
             self.get_targets_single, points, batch_gt_bboxes_3d,
             batch_gt_labels_3d, batch_pts_semantic_mask,
             batch_pts_instance_mask, aggregated_points, seed_points)

        # Per-sample results -> batched tensors.
        center_targets = torch.stack(center_targets)
        positive_mask = torch.stack(positive_mask)
        negative_mask = torch.stack(negative_mask)
        dir_class_targets = torch.stack(dir_class_targets)
        dir_res_targets = torch.stack(dir_res_targets)
        size_res_targets = torch.stack(size_res_targets)
        mask_targets = torch.stack(mask_targets)
        centerness_targets = torch.stack(centerness_targets).detach()
        corner3d_targets = torch.stack(corner3d_targets)
        vote_targets = torch.stack(vote_targets)
        vote_mask = torch.stack(vote_mask)

        # Turn absolute centers into offsets from the aggregated points.
        center_targets -= bbox_preds_dict['aggregated_points']

        # Every positive or negative candidate contributes to the
        # centerness loss for all classes; weights are normalised to sum
        # to (approximately) one, with 1e-6 guarding against an empty set.
        centerness_weights = (positive_mask | negative_mask).unsqueeze(
            -1).repeat(1, 1, self.num_classes).float()
        centerness_weights = centerness_weights / \
            (centerness_weights.sum() + 1e-6)
        vote_mask = vote_mask / (vote_mask.sum() + 1e-6)

        box_loss_weights = positive_mask / (positive_mask.sum() + 1e-6)

        # The heading residual is only supervised on the ground-truth
        # direction bin of each positive candidate.
        batch_size, proposal_num = dir_class_targets.shape[:2]
        heading_label_one_hot = dir_class_targets.new_zeros(
            (batch_size, proposal_num, self.num_dir_bins))
        heading_label_one_hot.scatter_(2, dir_class_targets.unsqueeze(-1), 1)
        heading_res_loss_weight = heading_label_one_hot * \
            box_loss_weights.unsqueeze(-1)

        return (vote_targets, center_targets, size_res_targets,
                dir_class_targets, dir_res_targets, mask_targets,
                centerness_targets, corner3d_targets, vote_mask, positive_mask,
                negative_mask, centerness_weights, box_loss_weights,
                heading_res_loss_weight)

    def get_targets_single(self,
                           points: Tensor,
                           gt_bboxes_3d: BaseInstance3DBoxes,
                           gt_labels_3d: Tensor,
                           pts_semantic_mask: Optional[Tensor] = None,
                           pts_instance_mask: Optional[Tensor] = None,
                           aggregated_points: Optional[Tensor] = None,
                           seed_points: Optional[Tensor] = None,
                           **kwargs) -> Tuple[Tensor, ...]:
        """Generate targets of ssd3d head for single batch.

        Args:
            points (torch.Tensor): Points of each batch.
            gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): Ground truth
                boxes of each batch.
            gt_labels_3d (torch.Tensor): Labels of each batch. Entries
                equal to -1 are treated as invalid and dropped.
            pts_semantic_mask (torch.Tensor): Point-wise semantic
                label of each batch. Only checked by the entry assertion.
            pts_instance_mask (torch.Tensor): Point-wise instance
                label of each batch. Not used by this implementation.
            aggregated_points (torch.Tensor): Aggregated points from
                candidate points layer.
            seed_points (torch.Tensor): Seed points of candidate points.

        Returns:
            tuple[torch.Tensor]: Targets of ssd3d head, in this order:
            vote_targets, center_targets, size_res_targets,
            dir_class_targets, dir_res_targets, mask_targets,
            centerness_targets, corner3d_targets, vote_mask, positive_mask,
            negative_mask.
        """
        assert self.bbox_coder.with_rot or pts_semantic_mask is not None
        gt_bboxes_3d = gt_bboxes_3d.to(points.device)
        valid_gt = gt_labels_3d != -1
        gt_bboxes_3d = gt_bboxes_3d[valid_gt]
        gt_labels_3d = gt_labels_3d[valid_gt]

        # Generate fake GT for empty scene: all-zero targets, no positive
        # candidates and every candidate marked as negative.
        if valid_gt.sum() == 0:
            num_cand = self.num_candidates
            vote_targets = points.new_zeros(num_cand, 3)
            center_targets = points.new_zeros(num_cand, 3)
            size_res_targets = points.new_zeros(num_cand, 3)
            dir_class_targets = points.new_zeros(num_cand, dtype=torch.int64)
            dir_res_targets = points.new_zeros(num_cand)
            mask_targets = points.new_zeros(num_cand, dtype=torch.int64)
            centerness_targets = points.new_zeros(num_cand, self.num_classes)
            corner3d_targets = points.new_zeros(num_cand, 8, 3)
            vote_mask = points.new_zeros(num_cand, dtype=torch.bool)
            positive_mask = points.new_zeros(num_cand, dtype=torch.bool)
            negative_mask = points.new_ones(num_cand, dtype=torch.bool)
            return (vote_targets, center_targets, size_res_targets,
                    dir_class_targets, dir_res_targets, mask_targets,
                    centerness_targets, corner3d_targets, vote_mask,
                    positive_mask, negative_mask)

        gt_corner3d = gt_bboxes_3d.corners

        (center_targets, size_targets, dir_class_targets,
         dir_res_targets) = self.bbox_coder.encode(gt_bboxes_3d, gt_labels_3d)

        points_mask, assignment = self._assign_targets_by_points_inside(
            gt_bboxes_3d, aggregated_points)

        # Gather the per-box targets for the box assigned to each candidate.
        center_targets = center_targets[assignment]
        size_res_targets = size_targets[assignment]
        mask_targets = gt_labels_3d[assignment]
        dir_class_targets = dir_class_targets[assignment]
        dir_res_targets = dir_res_targets[assignment]
        corner3d_targets = gt_corner3d[assignment]

        # A candidate is positive when it lies inside some GT box and is
        # within ``pos_distance_thr`` of the assigned center shifted along z
        # by the size target; it is negative when it lies inside no box.
        top_center_targets = center_targets.clone()
        top_center_targets[:, 2] += size_res_targets[:, 2]
        dist = torch.norm(aggregated_points - top_center_targets, dim=1)
        dist_mask = dist < self.train_cfg.pos_distance_thr
        inside_any_box = points_mask.max(1)[0]
        positive_mask = (inside_any_box > 0) & dist_mask
        negative_mask = (inside_any_box == 0)

        # Centerness loss targets
        canonical_xyz = aggregated_points - center_targets
        if self.bbox_coder.with_rot:
            # TODO: Align points rotation implementation of
            # LiDARInstance3DBoxes and DepthInstance3DBoxes
            canonical_xyz = rotation_3d_in_axis(
                canonical_xyz.unsqueeze(0).transpose(0, 1),
                -gt_bboxes_3d.yaw[assignment],
                axis=2).squeeze(1)

        def _min_max_ratio(dist_a, dist_b):
            """Return min/max elementwise, mapping the 0/0 case to 0."""
            larger = torch.max(dist_a, dist_b)
            # Both inputs are clamped to be non-negative, so ``larger == 0``
            # implies the numerator is 0 as well; dividing by 1 there gives
            # 0 instead of NaN and leaves all other entries untouched.
            safe_larger = torch.where(larger > 0, larger,
                                      torch.ones_like(larger))
            return torch.min(dist_a, dist_b) / safe_larger

        # Per axis: clamped distances from the candidate to the two
        # opposite box faces, reduced to a ratio in [0, 1].
        axis_centerness = []
        for axis in range(3):
            dist_plus = torch.clamp(
                size_res_targets[:, axis] - canonical_xyz[:, axis], min=0)
            dist_minus = torch.clamp(
                size_res_targets[:, axis] + canonical_xyz[:, axis], min=0)
            axis_centerness.append(_min_max_ratio(dist_plus, dist_minus))
        centerness_l, centerness_w, centerness_h = axis_centerness

        centerness_targets = torch.clamp(
            centerness_l * centerness_w * centerness_h, min=0)
        centerness_targets = centerness_targets.pow(1 / 3.0)
        centerness_targets = torch.clamp(centerness_targets, min=0, max=1)

        # Place the scalar centerness in the channel of the assigned class.
        proposal_num = centerness_targets.shape[0]
        one_hot_centerness_targets = centerness_targets.new_zeros(
            (proposal_num, self.num_classes))
        one_hot_centerness_targets.scatter_(1, mask_targets.unsqueeze(-1), 1)
        centerness_targets = centerness_targets.unsqueeze(
            1) * one_hot_centerness_targets

        # Vote loss targets: seed points inside the enlarged GT boxes are
        # supervised to move to the gravity center of their assigned box.
        enlarged_gt_bboxes_3d = gt_bboxes_3d.enlarged_box(
            self.train_cfg.expand_dims_length)
        enlarged_gt_bboxes_3d.tensor[:, 2] -= self.train_cfg.expand_dims_length
        vote_mask, vote_assignment = self._assign_targets_by_points_inside(
            enlarged_gt_bboxes_3d, seed_points)

        vote_targets = gt_bboxes_3d.gravity_center
        vote_targets = vote_targets[vote_assignment] - seed_points
        vote_mask = vote_mask.max(1)[0] > 0

        return (vote_targets, center_targets, size_res_targets,
                dir_class_targets, dir_res_targets, mask_targets,
                centerness_targets, corner3d_targets, vote_mask, positive_mask,
                negative_mask)

    def predict_by_feat(self, points: List[torch.Tensor],
                        bbox_preds_dict: dict, batch_input_metas: List[dict],
                        **kwargs) -> List[InstanceData]:
        """Generate bboxes from vote head predictions.

        Args:
            points (List[torch.Tensor]): Input points of multiple samples.
                Samples are indexed individually, so they do not need to
                contain the same number of points.
            bbox_preds_dict (dict): Predictions from vote head.
            batch_input_metas (list[dict]): Each item
                contains the meta information of each sample.

        Returns:
            list[:obj:`InstanceData`]: List of processed predictions. Each
            InstanceData contains 3d Bounding boxes and corresponding
            scores and labels.
        """
        # decode boxes; the per-box score is the maximum class score.
        sem_scores = bbox_preds_dict['obj_scores'].sigmoid().transpose(1, 2)
        obj_scores = sem_scores.max(-1)[0]
        bbox3d = self.bbox_coder.decode(bbox_preds_dict)
        batch_size = bbox3d.shape[0]

        results_list = []
        for b in range(batch_size):
            # Only the first three point channels (coordinates) are used.
            bbox_selected, score_selected, labels = self.multiclass_nms_single(
                obj_scores[b], sem_scores[b], bbox3d[b], points[b][..., :3],
                batch_input_metas[b])

            bbox = batch_input_metas[b]['box_type_3d'](
                bbox_selected.clone(),
                box_dim=bbox_selected.shape[-1],
                with_yaw=self.bbox_coder.with_rot)

            temp_results = InstanceData()
            temp_results.bboxes_3d = bbox
            temp_results.scores_3d = score_selected
            temp_results.labels_3d = labels
            results_list.append(temp_results)

        return results_list

    def multiclass_nms_single(self, obj_scores: Tensor, sem_scores: Tensor,
                              bbox: Tensor, points: Tensor,
                              input_meta: dict) -> Tuple[Tensor]:
        """Multi-class nms in single batch.

        Args:
            obj_scores (torch.Tensor): Objectness score of bounding boxes.
            sem_scores (torch.Tensor): Semantic class score of bounding boxes.
            bbox (torch.Tensor): Predicted bounding boxes.
            points (torch.Tensor): Input points.
            input_meta (dict): Point cloud and image's meta info. Its
                ``box_type_3d`` entry is used to wrap ``bbox``.

        Returns:
            tuple[torch.Tensor]: Bounding boxes, scores and labels.

        Raises:
            NotImplementedError: If the wrapped box is neither a
                ``LiDARInstance3DBoxes`` nor a ``DepthInstance3DBoxes``.
        """
        bbox = input_meta['box_type_3d'](
            bbox.clone(),
            box_dim=bbox.shape[-1],
            with_yaw=self.bbox_coder.with_rot,
            origin=(0.5, 0.5, 0.5))

        if not isinstance(bbox, (LiDARInstance3DBoxes, DepthInstance3DBoxes)):
            raise NotImplementedError('Unsupported bbox type!')

        box_indices = bbox.points_in_boxes_all(points)
        # NOTE(review): a sum of indicator values is presumably never
        # negative, so ``>= 0`` keeps every box and the "non-empty" filter
        # is effectively disabled. Kept as is to preserve detection
        # results; confirm whether a positive threshold was intended.
        nonempty_box_mask = box_indices.T.sum(1) >= 0

        # Axis-aligned (min_xyz, max_xyz) envelope of every box's corners.
        corner3d = bbox.corners
        minmax_box3d = torch.cat(
            [corner3d.min(dim=1)[0], corner3d.max(dim=1)[0]], dim=1)

        bbox_classes = torch.argmax(sem_scores, -1)
        # 2D NMS on columns (x_min, y_min, x_max, y_max) of the envelope,
        # batched by predicted class.
        nms_keep = batched_nms(
            minmax_box3d[nonempty_box_mask][:, [0, 1, 3, 4]],
            obj_scores[nonempty_box_mask], bbox_classes[nonempty_box_mask],
            self.test_cfg.nms_cfg)[1]

        if nms_keep.shape[0] > self.test_cfg.max_output_num:
            nms_keep = nms_keep[:self.test_cfg.max_output_num]

        # filter empty boxes and boxes with low score
        scores_mask = (obj_scores >= self.test_cfg.score_thr)
        # ``nms_keep`` indexes the non-empty subset; map it back to
        # indices over all boxes before building the keep mask.
        nonempty_box_inds = torch.nonzero(
            nonempty_box_mask, as_tuple=False).flatten()
        nonempty_mask = torch.zeros_like(bbox_classes).scatter(
            0, nonempty_box_inds[nms_keep], 1)
        selected = (nonempty_mask.bool() & scores_mask.bool())

        kept_boxes = bbox[selected].tensor
        kept_scores = obj_scores[selected]
        kept_classes = bbox_classes[selected]

        if self.test_cfg.per_class_proposal:
            # Every kept box is emitted once per class.
            # NOTE(review): the score attached to each copy is the box's
            # ``obj_scores`` value regardless of the class ``k``; confirm
            # whether a per-class score was intended.
            num_classes = sem_scores.shape[-1]
            bbox_selected = torch.cat([kept_boxes] * num_classes, 0)
            score_selected = torch.cat([kept_scores] * num_classes, 0)
            labels = torch.cat([
                torch.full_like(kept_classes, k) for k in range(num_classes)
            ], 0)
        else:
            bbox_selected = kept_boxes
            score_selected = kept_scores
            labels = kept_classes

        return bbox_selected, score_selected, labels

    def _assign_targets_by_points_inside(self, bboxes_3d: BaseInstance3DBoxes,
                                         points: Tensor) -> Tuple:
        """Compute assignment by checking whether point is inside bbox.

        Args:
            bboxes_3d (BaseInstance3DBoxes): Instance of bounding boxes.
            points (torch.Tensor): Points of a batch.

        Returns:
            tuple[torch.Tensor]: Flags indicating whether each point is
                inside bbox and the index of box where each point are in.

        Raises:
            NotImplementedError: If ``bboxes_3d`` is neither a
                ``LiDARInstance3DBoxes`` nor a ``DepthInstance3DBoxes``.
        """
        if not isinstance(bboxes_3d,
                          (LiDARInstance3DBoxes, DepthInstance3DBoxes)):
            raise NotImplementedError('Unsupported bbox type!')

        points_mask = bboxes_3d.points_in_boxes_all(points)
        # The assigned box is the argmax over the last dimension of the
        # mask returned by ``points_in_boxes_all``.
        assignment = points_mask.argmax(dim=-1)

        return points_mask, assignment