"vscode:/vscode.git/clone" did not exist on "a98720382557ff806a6a9adf1f5731c9652564bf"
vote_head.py 25.3 KB
Newer Older
wuyuefeng's avatar
Votenet  
wuyuefeng committed
1
2
3
import numpy as np
import torch
from mmcv.cnn import ConvModule
zhangwenwei's avatar
zhangwenwei committed
4
5
from torch import nn as nn
from torch.nn import functional as F
wuyuefeng's avatar
Votenet  
wuyuefeng committed
6
7
8
9
10

from mmdet3d.core.post_processing import aligned_3d_nms
from mmdet3d.models.builder import build_loss
from mmdet3d.models.losses import chamfer_distance
from mmdet3d.models.model_utils import VoteModule
11
from mmdet3d.ops import build_sa_module, furthest_point_sample
zhangwenwei's avatar
zhangwenwei committed
12
from mmdet.core import build_bbox_coder, multi_apply
wuyuefeng's avatar
Votenet  
wuyuefeng committed
13
14
15
16
17
from mmdet.models import HEADS


@HEADS.register_module()
class VoteHead(nn.Module):
    r"""Bbox head of `Votenet <https://arxiv.org/abs/1904.09664>`_.

    Args:
        num_classes (int): The number of classes.
        bbox_coder (:obj:`BaseBBoxCoder`): Bbox coder for encoding and
            decoding boxes.
        train_cfg (dict): Config for training.
        test_cfg (dict): Config for testing.
        vote_moudule_cfg (dict): Config of VoteModule for point-wise votes
            (the misspelled argument name is kept for compatibility).
        vote_aggregation_cfg (dict): Config of vote aggregation layer.
        feat_channels (tuple[int]): Convolution channels of
            prediction layer.
        conv_cfg (dict): Config of convolution in prediction layer.
        norm_cfg (dict): Config of BN in prediction layer.
        objectness_loss (dict): Config of objectness loss.
        center_loss (dict): Config of center loss.
        dir_class_loss (dict): Config of direction classification loss.
        dir_res_loss (dict): Config of direction residual regression loss.
        size_class_loss (dict): Config of size classification loss.
        size_res_loss (dict): Config of size residual regression loss.
        semantic_loss (dict): Config of the per-proposal semantic
            classification loss.
    """

    def __init__(self,
                 num_classes,
                 bbox_coder,
                 train_cfg=None,
                 test_cfg=None,
                 vote_moudule_cfg=None,
                 vote_aggregation_cfg=None,
                 feat_channels=(128, 128),
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=dict(type='BN1d'),
                 objectness_loss=None,
                 center_loss=None,
                 dir_class_loss=None,
                 dir_res_loss=None,
                 size_class_loss=None,
                 size_res_loss=None,
                 semantic_loss=None):
        super(VoteHead, self).__init__()
        self.num_classes = num_classes
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.gt_per_seed = vote_moudule_cfg['gt_per_seed']
        self.num_proposal = vote_aggregation_cfg['num_point']

        self.objectness_loss = build_loss(objectness_loss)
        self.center_loss = build_loss(center_loss)
        self.dir_class_loss = build_loss(dir_class_loss)
        self.dir_res_loss = build_loss(dir_res_loss)
        self.size_class_loss = build_loss(size_class_loss)
        self.size_res_loss = build_loss(size_res_loss)
        self.semantic_loss = build_loss(semantic_loss)

        # the aggregation layer consumes the vote features directly
        assert vote_aggregation_cfg['mlp_channels'][0] == vote_moudule_cfg[
            'in_channels']

        self.bbox_coder = build_bbox_coder(bbox_coder)
        self.num_sizes = self.bbox_coder.num_sizes
        self.num_dir_bins = self.bbox_coder.num_dir_bins

        self.vote_module = VoteModule(**vote_moudule_cfg)
        self.vote_aggregation = build_sa_module(vote_aggregation_cfg)

        prev_channel = vote_aggregation_cfg['mlp_channels'][-1]
        conv_pred_list = list()
        for k in range(len(feat_channels)):
            conv_pred_list.append(
                ConvModule(
                    prev_channel,
                    feat_channels[k],
                    1,
                    padding=0,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    bias=True,
                    inplace=True))
            prev_channel = feat_channels[k]
        self.conv_pred = nn.Sequential(*conv_pred_list)

        # Output channels per proposal: objectness scores (2),
        # center residual (3), heading class + residual (num_dir_bins * 2),
        # size class + residual (num_sizes * 4) and semantic scores
        # (num_classes).
        conv_out_channel = (2 + 3 + self.num_dir_bins * 2 +
                            self.num_sizes * 4 + num_classes)
        self.conv_pred.add_module('conv_out',
                                  nn.Conv1d(prev_channel, conv_out_channel, 1))

    def init_weights(self):
        """Initialize weights of VoteHead."""
        pass

    def forward(self, feat_dict, sample_mod):
        """Forward pass.

        Note:
            The forward of VoteHead is divided into 4 steps:

                1. Generate vote_points from seed_points.
                2. Aggregate vote_points.
                3. Predict bbox and score.
                4. Decode predictions.

        Args:
            feat_dict (dict): Feature dict from backbone.
            sample_mod (str): Sample mode for vote aggregation layer.
                valid modes are "vote", "seed" and "random".

        Returns:
            dict: Predictions of vote head.
        """
        assert sample_mod in ['vote', 'seed', 'random']

        seed_points = feat_dict['fp_xyz'][-1]
        seed_features = feat_dict['fp_features'][-1]
        seed_indices = feat_dict['fp_indices'][-1]

        # 1. generate vote_points from seed_points
        vote_points, vote_features = self.vote_module(seed_points,
                                                      seed_features)
        results = dict(
            seed_points=seed_points,
            seed_indices=seed_indices,
            vote_points=vote_points,
            vote_features=vote_features)

        # 2. aggregate vote_points
        if sample_mod == 'vote':
            # use fps in vote_aggregation
            sample_indices = None
        elif sample_mod == 'seed':
            # FPS on seed and choose the votes corresponding to the seeds
            sample_indices = furthest_point_sample(seed_points,
                                                   self.num_proposal)
        elif sample_mod == 'random':
            # Random sampling from the votes. Draw the int32 indices directly
            # on the seeds' device instead of copy-constructing a CPU tensor
            # through ``new_tensor`` (which warns and costs an extra copy).
            batch_size, num_seed = seed_points.shape[:2]
            sample_indices = torch.randint(
                0,
                num_seed, (batch_size, self.num_proposal),
                dtype=torch.int32,
                device=seed_points.device)
        else:
            raise NotImplementedError(
                f'Sample mode {sample_mod} is not supported!')

        vote_aggregation_ret = self.vote_aggregation(vote_points,
                                                     vote_features,
                                                     sample_indices)
        aggregated_points, features, aggregated_indices = vote_aggregation_ret
        results['aggregated_points'] = aggregated_points
        results['aggregated_features'] = features
        results['aggregated_indices'] = aggregated_indices

        # 3. predict bbox and score
        predictions = self.conv_pred(features)

        # 4. decode predictions
        decode_res = self.bbox_coder.split_pred(predictions, aggregated_points)
        results.update(decode_res)

        return results

    def loss(self,
             bbox_preds,
             points,
             gt_bboxes_3d,
             gt_labels_3d,
             pts_semantic_mask=None,
             pts_instance_mask=None,
             img_metas=None,
             gt_bboxes_ignore=None,
             ret_target=False):
        """Compute loss.

        Args:
            bbox_preds (dict): Predictions from forward of vote head.
            points (list[torch.Tensor]): Input points.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth \
                bboxes of each sample.
            gt_labels_3d (list[torch.Tensor]): Labels of each sample.
            pts_semantic_mask (None | list[torch.Tensor]): Point-wise
                semantic mask.
            pts_instance_mask (None | list[torch.Tensor]): Point-wise
                instance mask.
            img_metas (list[dict]): Contain pcd and img's meta info.
                Not used by this head.
            gt_bboxes_ignore (None | list[torch.Tensor]): Specify which
                bounding boxes to ignore. Not used by this head.
            ret_target (bool): Whether to also return the computed targets
                under the ``'targets'`` key.

        Returns:
            dict: Losses of Votenet.
        """
        targets = self.get_targets(points, gt_bboxes_3d, gt_labels_3d,
                                   pts_semantic_mask, pts_instance_mask,
                                   bbox_preds)
        (vote_targets, vote_target_masks, size_class_targets, size_res_targets,
         dir_class_targets, dir_res_targets, center_targets, mask_targets,
         valid_gt_masks, objectness_targets, objectness_weights,
         box_loss_weights, valid_gt_weights) = targets

        # calculate vote loss
        vote_loss = self.vote_module.get_loss(bbox_preds['seed_points'],
                                              bbox_preds['vote_points'],
                                              bbox_preds['seed_indices'],
                                              vote_target_masks, vote_targets)

        # calculate objectness loss
        objectness_loss = self.objectness_loss(
            bbox_preds['obj_scores'].transpose(2, 1),
            objectness_targets,
            weight=objectness_weights)

        # calculate center loss
        source2target_loss, target2source_loss = self.center_loss(
            bbox_preds['center'],
            center_targets,
            src_weight=box_loss_weights,
            dst_weight=valid_gt_weights)
        center_loss = source2target_loss + target2source_loss

        # calculate direction class loss
        dir_class_loss = self.dir_class_loss(
            bbox_preds['dir_class'].transpose(2, 1),
            dir_class_targets,
            weight=box_loss_weights)

        # calculate direction residual loss; only the residual predicted for
        # the target heading bin is supervised
        batch_size, proposal_num = size_class_targets.shape[:2]
        heading_label_one_hot = vote_targets.new_zeros(
            (batch_size, proposal_num, self.num_dir_bins))
        heading_label_one_hot.scatter_(2, dir_class_targets.unsqueeze(-1), 1)
        dir_res_norm = torch.sum(
            bbox_preds['dir_res_norm'] * heading_label_one_hot, -1)
        dir_res_loss = self.dir_res_loss(
            dir_res_norm, dir_res_targets, weight=box_loss_weights)

        # calculate size class loss
        size_class_loss = self.size_class_loss(
            bbox_preds['size_class'].transpose(2, 1),
            size_class_targets,
            weight=box_loss_weights)

        # calculate size residual loss; only the residual predicted for the
        # target size class is supervised
        one_hot_size_targets = vote_targets.new_zeros(
            (batch_size, proposal_num, self.num_sizes))
        one_hot_size_targets.scatter_(2, size_class_targets.unsqueeze(-1), 1)
        one_hot_size_targets_expand = one_hot_size_targets.unsqueeze(
            -1).repeat(1, 1, 1, 3).contiguous()
        size_residual_norm = torch.sum(
            bbox_preds['size_res_norm'] * one_hot_size_targets_expand, 2)
        box_loss_weights_expand = box_loss_weights.unsqueeze(-1).repeat(
            1, 1, 3)
        size_res_loss = self.size_res_loss(
            size_residual_norm,
            size_res_targets,
            weight=box_loss_weights_expand)

        # calculate semantic loss
        semantic_loss = self.semantic_loss(
            bbox_preds['sem_scores'].transpose(2, 1),
            mask_targets,
            weight=box_loss_weights)

        losses = dict(
            vote_loss=vote_loss,
            objectness_loss=objectness_loss,
            semantic_loss=semantic_loss,
            center_loss=center_loss,
            dir_class_loss=dir_class_loss,
            dir_res_loss=dir_res_loss,
            size_class_loss=size_class_loss,
            size_res_loss=size_res_loss)

        if ret_target:
            losses['targets'] = targets

        return losses

    def get_targets(self,
                    points,
                    gt_bboxes_3d,
                    gt_labels_3d,
                    pts_semantic_mask=None,
                    pts_instance_mask=None,
                    bbox_preds=None):
        """Generate targets of vote head.

        Note:
            Samples without any ground truth are replaced *in place* in
            ``gt_bboxes_3d`` / ``gt_labels_3d`` by a single all-zero fake box
            with label 0; that box is masked out via ``valid_gt_masks``.

        Args:
            points (list[torch.Tensor]): Points of each batch.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth \
                bboxes of each batch.
            gt_labels_3d (list[torch.Tensor]): Labels of each batch.
            pts_semantic_mask (None | list[torch.Tensor]): Point-wise semantic
                label of each batch.
            pts_instance_mask (None | list[torch.Tensor]): Point-wise instance
                label of each batch.
            bbox_preds (dict): Predictions of vote head; only
                ``'aggregated_points'`` is read here.

        Returns:
            tuple[torch.Tensor]: Targets of vote head.
        """
        # find empty example
        valid_gt_masks = list()
        gt_num = list()
        for index in range(len(gt_labels_3d)):
            if len(gt_labels_3d[index]) == 0:
                fake_box = gt_bboxes_3d[index].tensor.new_zeros(
                    1, gt_bboxes_3d[index].tensor.shape[-1])
                gt_bboxes_3d[index] = gt_bboxes_3d[index].new_box(fake_box)
                gt_labels_3d[index] = gt_labels_3d[index].new_zeros(1)
                valid_gt_masks.append(gt_labels_3d[index].new_zeros(1))
                gt_num.append(1)
            else:
                valid_gt_masks.append(gt_labels_3d[index].new_ones(
                    gt_labels_3d[index].shape))
                gt_num.append(gt_labels_3d[index].shape[0])
        max_gt_num = max(gt_num)

        # Default each mask independently so that a provided instance mask
        # is never discarded and a missing one does not break multi_apply.
        num_samples = len(gt_labels_3d)
        if pts_semantic_mask is None:
            pts_semantic_mask = [None] * num_samples
        if pts_instance_mask is None:
            pts_instance_mask = [None] * num_samples

        aggregated_points = [
            bbox_preds['aggregated_points'][i] for i in range(num_samples)
        ]

        (vote_targets, vote_target_masks, size_class_targets, size_res_targets,
         dir_class_targets, dir_res_targets, center_targets, mask_targets,
         objectness_targets, objectness_masks) = multi_apply(
             self.get_targets_single, points, gt_bboxes_3d, gt_labels_3d,
             pts_semantic_mask, pts_instance_mask, aggregated_points)

        # pad targets as original code of votenet.
        for index in range(num_samples):
            pad_num = max_gt_num - gt_labels_3d[index].shape[0]
            center_targets[index] = F.pad(center_targets[index],
                                          (0, 0, 0, pad_num))
            valid_gt_masks[index] = F.pad(valid_gt_masks[index], (0, pad_num))

        vote_targets = torch.stack(vote_targets)
        vote_target_masks = torch.stack(vote_target_masks)
        center_targets = torch.stack(center_targets)
        valid_gt_masks = torch.stack(valid_gt_masks)

        # loss weights are normalised to sum to (approximately) one
        objectness_targets = torch.stack(objectness_targets)
        objectness_weights = torch.stack(objectness_masks)
        objectness_weights /= (torch.sum(objectness_weights) + 1e-6)
        box_loss_weights = objectness_targets.float() / (
            torch.sum(objectness_targets).float() + 1e-6)
        valid_gt_weights = valid_gt_masks.float() / (
            torch.sum(valid_gt_masks.float()) + 1e-6)
        dir_class_targets = torch.stack(dir_class_targets)
        dir_res_targets = torch.stack(dir_res_targets)
        size_class_targets = torch.stack(size_class_targets)
        size_res_targets = torch.stack(size_res_targets)
        mask_targets = torch.stack(mask_targets)

        return (vote_targets, vote_target_masks, size_class_targets,
                size_res_targets, dir_class_targets, dir_res_targets,
                center_targets, mask_targets, valid_gt_masks,
                objectness_targets, objectness_weights, box_loss_weights,
                valid_gt_weights)

    def get_targets_single(self,
                           points,
                           gt_bboxes_3d,
                           gt_labels_3d,
                           pts_semantic_mask=None,
                           pts_instance_mask=None,
                           aggregated_points=None):
        """Generate targets of vote head for single batch.

        Args:
            points (torch.Tensor): Points of each batch.
            gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): Ground truth \
                boxes of each batch.
            gt_labels_3d (torch.Tensor): Labels of each batch.
            pts_semantic_mask (None | torch.Tensor): Point-wise semantic
                label of each batch.
            pts_instance_mask (None | torch.Tensor): Point-wise instance
                label of each batch.
            aggregated_points (torch.Tensor): Aggregated points from
                vote aggregation layer.

        Returns:
            tuple[torch.Tensor]: Targets of vote head.
        """
        assert self.bbox_coder.with_rot or pts_semantic_mask is not None

        gt_bboxes_3d = gt_bboxes_3d.to(points.device)

        # generate votes target
        num_points = points.shape[0]
        if self.bbox_coder.with_rot:
            vote_targets = points.new_zeros([num_points, 3 * self.gt_per_seed])
            vote_target_masks = points.new_zeros([num_points],
                                                 dtype=torch.long)
            # number of boxes each point has been assigned to so far; selects
            # which of the ``gt_per_seed`` vote slots the next box fills
            vote_target_idx = points.new_zeros([num_points], dtype=torch.long)
            box_indices_all = gt_bboxes_3d.points_in_boxes(points)
            for i in range(gt_labels_3d.shape[0]):
                box_indices = box_indices_all[:, i]
                indices = torch.nonzero(
                    box_indices, as_tuple=False).squeeze(-1)
                selected_points = points[indices]
                vote_target_masks[indices] = 1
                vote_targets_tmp = vote_targets[indices]
                votes = gt_bboxes_3d.gravity_center[i].unsqueeze(
                    0) - selected_points[:, :3]

                for j in range(self.gt_per_seed):
                    column_indices = torch.nonzero(
                        vote_target_idx[indices] == j,
                        as_tuple=False).squeeze(-1)
                    vote_targets_tmp[column_indices,
                                     j * 3:j * 3 + 3] = votes[column_indices]
                    if j == 0:
                        # a point's first box fills every slot, so points in
                        # fewer boxes still have ``gt_per_seed`` targets
                        vote_targets_tmp[column_indices] = votes[
                            column_indices].repeat(1, self.gt_per_seed)

                vote_targets[indices] = vote_targets_tmp
                # Clamp to the last slot (was hard-coded to 2, which is only
                # right for gt_per_seed == 3); once every slot is used, later
                # boxes overwrite the last one.
                vote_target_idx[indices] = torch.clamp(
                    vote_target_idx[indices] + 1, max=self.gt_per_seed - 1)
        elif pts_semantic_mask is not None:
            vote_targets = points.new_zeros([num_points, 3])
            vote_target_masks = points.new_zeros([num_points],
                                                 dtype=torch.long)

            for i in torch.unique(pts_instance_mask):
                indices = torch.nonzero(
                    pts_instance_mask == i, as_tuple=False).squeeze(-1)
                # only instances whose semantic label is one of the detected
                # classes produce votes; the vote points at the center of the
                # instance's axis-aligned extent
                if pts_semantic_mask[indices[0]] < self.num_classes:
                    selected_points = points[indices, :3]
                    center = 0.5 * (
                        selected_points.min(0)[0] + selected_points.max(0)[0])
                    vote_targets[indices, :] = center - selected_points
                    vote_target_masks[indices] = 1
            vote_targets = vote_targets.repeat((1, self.gt_per_seed))
        else:
            raise NotImplementedError

        (center_targets, size_class_targets, size_res_targets,
         dir_class_targets,
         dir_res_targets) = self.bbox_coder.encode(gt_bboxes_3d, gt_labels_3d)

        # assign every proposal to its nearest ground-truth center
        proposal_num = aggregated_points.shape[0]
        distance1, _, assignment, _ = chamfer_distance(
            aggregated_points.unsqueeze(0),
            center_targets.unsqueeze(0),
            reduction='none')
        assignment = assignment.squeeze(0)
        euclidean_distance1 = torch.sqrt(distance1.squeeze(0) + 1e-6)

        objectness_targets = points.new_zeros((proposal_num), dtype=torch.long)
        objectness_targets[
            euclidean_distance1 < self.train_cfg['pos_distance_thr']] = 1

        # proposals between the two thresholds get zero objectness weight
        objectness_masks = points.new_zeros((proposal_num))
        objectness_masks[
            euclidean_distance1 < self.train_cfg['pos_distance_thr']] = 1.0
        objectness_masks[
            euclidean_distance1 > self.train_cfg['neg_distance_thr']] = 1.0

        dir_class_targets = dir_class_targets[assignment]
        dir_res_targets = dir_res_targets[assignment]
        dir_res_targets /= (np.pi / self.num_dir_bins)
        size_class_targets = size_class_targets[assignment]
        size_res_targets = size_res_targets[assignment]

        # normalise size residuals by the mean size of the target size class
        one_hot_size_targets = gt_bboxes_3d.tensor.new_zeros(
            (proposal_num, self.num_sizes))
        one_hot_size_targets.scatter_(1, size_class_targets.unsqueeze(-1), 1)
        one_hot_size_targets = one_hot_size_targets.unsqueeze(-1).repeat(
            1, 1, 3)
        mean_sizes = size_res_targets.new_tensor(
            self.bbox_coder.mean_sizes).unsqueeze(0)
        pos_mean_sizes = torch.sum(one_hot_size_targets * mean_sizes, 1)
        size_res_targets /= pos_mean_sizes

        mask_targets = gt_labels_3d[assignment]

        return (vote_targets, vote_target_masks, size_class_targets,
                size_res_targets,
                dir_class_targets, dir_res_targets, center_targets,
                mask_targets.long(), objectness_targets, objectness_masks)

    def get_bboxes(self,
                   points,
                   bbox_preds,
                   input_metas,
                   rescale=False,
                   use_nms=True):
        """Generate bboxes from vote head predictions.

        Args:
            points (torch.Tensor): Input points.
            bbox_preds (dict): Predictions from vote head.
            input_metas (list[dict]): Point cloud and image's meta info.
            rescale (bool): Whether to rescale bboxes.
            use_nms (bool): Whether to apply NMS, skip nms postprocessing
                while using vote head in rpn stage.

        Returns:
            list[tuple[torch.Tensor]]: Bounding boxes, scores and labels.
                When ``use_nms`` is False, the decoded boxes of the whole
                batch are returned instead.
        """
        # decode boxes; the last objectness channel is the positive score
        obj_scores = F.softmax(bbox_preds['obj_scores'], dim=-1)[..., -1]
        sem_scores = F.softmax(bbox_preds['sem_scores'], dim=-1)
        bbox3d = self.bbox_coder.decode(bbox_preds)

        if use_nms:
            batch_size = bbox3d.shape[0]
            results = list()
            for b in range(batch_size):
                bbox_selected, score_selected, labels = \
                    self.multiclass_nms_single(obj_scores[b], sem_scores[b],
                                               bbox3d[b], points[b, ..., :3],
                                               input_metas[b])
                bbox = input_metas[b]['box_type_3d'](
                    bbox_selected,
                    box_dim=bbox_selected.shape[-1],
                    with_yaw=self.bbox_coder.with_rot)
                results.append((bbox, score_selected, labels))

            return results
        else:
            return bbox3d

    def multiclass_nms_single(self, obj_scores, sem_scores, bbox, points,
                              input_meta):
        """Multi-class nms in single batch.

        Args:
            obj_scores (torch.Tensor): Objectness score of bounding boxes.
            sem_scores (torch.Tensor): semantic class score of bounding boxes.
            bbox (torch.Tensor): Predicted bounding boxes.
            points (torch.Tensor): Input points.
            input_meta (dict): Point cloud and image's meta info.

        Returns:
            tuple[torch.Tensor]: Bounding boxes, scores and labels.
        """
        bbox = input_meta['box_type_3d'](
            bbox,
            box_dim=bbox.shape[-1],
            with_yaw=self.bbox_coder.with_rot,
            origin=(0.5, 0.5, 0.5))
        box_indices = bbox.points_in_boxes(points)

        # axis-aligned (min_xyz, max_xyz) extent of every box for aligned NMS
        corner3d = bbox.corners
        minmax_box3d = torch.cat(
            [torch.min(corner3d, dim=1)[0],
             torch.max(corner3d, dim=1)[0]],
            dim=1)

        # keep only boxes containing more than 5 input points (one column of
        # ``box_indices`` per box)
        nonempty_box_mask = box_indices.T.sum(1) > 5

        bbox_classes = torch.argmax(sem_scores, -1)
        nms_selected = aligned_3d_nms(minmax_box3d[nonempty_box_mask],
                                      obj_scores[nonempty_box_mask],
                                      bbox_classes[nonempty_box_mask],
                                      self.test_cfg.nms_thr)

        # filter empty boxes and boxes with low score
        scores_mask = (obj_scores > self.test_cfg.score_thr)
        nonempty_box_inds = torch.nonzero(
            nonempty_box_mask, as_tuple=False).flatten()
        nonempty_mask = torch.zeros_like(bbox_classes).scatter(
            0, nonempty_box_inds[nms_selected], 1)
        selected = (nonempty_mask.bool() & scores_mask.bool())

        if self.test_cfg.per_class_proposal:
            # emit every kept box once per class, scored by
            # objectness * that class's semantic score
            kept_boxes = bbox[selected].tensor
            kept_obj_scores = obj_scores[selected]
            kept_sem_scores = sem_scores[selected]
            kept_classes = bbox_classes[selected]
            bbox_selected, score_selected, labels = [], [], []
            for k in range(sem_scores.shape[-1]):
                bbox_selected.append(kept_boxes)
                score_selected.append(kept_obj_scores * kept_sem_scores[:, k])
                labels.append(torch.full_like(kept_classes, k))
            bbox_selected = torch.cat(bbox_selected, 0)
            score_selected = torch.cat(score_selected, 0)
            labels = torch.cat(labels, 0)
        else:
            bbox_selected = bbox[selected].tensor
            score_selected = obj_scores[selected]
            labels = bbox_classes[selected]

        return bbox_selected, score_selected, labels