# primitive_head.py
import torch
from mmcv.cnn import ConvModule
from torch import nn as nn
from torch.nn import functional as F

from mmdet3d.models.builder import build_loss
from mmdet3d.models.model_utils import VoteModule
from mmdet3d.ops import build_sa_module, furthest_point_sample
from mmdet.core import multi_apply
from mmdet.models import HEADS


@HEADS.register_module()
class PrimitiveHead(nn.Module):
    r"""Primitive head of `H3DNet <https://arxiv.org/abs/2006.05682>`_.

    Args:
        num_dims (int): The dimension of primitive semantic information.
        num_classes (int): The number of class.
        primitive_mode (str): The mode of primitive module,
            avaliable mode ['z', 'xy', 'line'].
        bbox_coder (:obj:`BaseBBoxCoder`): Bbox coder for encoding and
            decoding boxes.
        train_cfg (dict): Config for training.
        test_cfg (dict): Config for testing.
        vote_moudule_cfg (dict): Config of VoteModule for point-wise votes.
        vote_aggregation_cfg (dict): Config of vote aggregation layer.
        feat_channels (tuple[int]): Convolution channels of
            prediction layer.
        upper_thresh (float): Threshold for line matching.
        surface_thresh (float): Threshold for suface matching.
        conv_cfg (dict): Config of convolution in prediction layer.
        norm_cfg (dict): Config of BN in prediction layer.
        objectness_loss (dict): Config of objectness loss.
        center_loss (dict): Config of center loss.
        semantic_loss (dict): Config of point-wise semantic segmentation loss.
    """

    def __init__(self,
                 num_dims,
                 num_classes,
                 primitive_mode,
                 train_cfg=None,
                 test_cfg=None,
                 vote_moudule_cfg=None,
                 vote_aggregation_cfg=None,
                 feat_channels=(128, 128),
                 upper_thresh=100.0,
                 surface_thresh=0.5,
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=dict(type='BN1d'),
                 objectness_loss=None,
                 center_loss=None,
                 semantic_reg_loss=None,
                 semantic_cls_loss=None):
        """Build the flag branch, vote module, aggregation and predictor."""
        super(PrimitiveHead, self).__init__()
        assert primitive_mode in ['z', 'xy', 'line']
        # Both configs are dereferenced below, so they are mandatory even
        # though the signature defaults them to ``None``.
        assert vote_moudule_cfg is not None, 'vote_moudule_cfg is required'
        assert vote_aggregation_cfg is not None, \
            'vote_aggregation_cfg is required'

        # The dimension of primitive semantic information.
        self.num_dims = num_dims
        self.num_classes = num_classes
        self.primitive_mode = primitive_mode
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.gt_per_seed = vote_moudule_cfg['gt_per_seed']
        self.num_proposal = vote_aggregation_cfg['num_point']
        self.upper_thresh = upper_thresh
        self.surface_thresh = surface_thresh

        self.objectness_loss = build_loss(objectness_loss)
        self.center_loss = build_loss(center_loss)
        self.semantic_reg_loss = build_loss(semantic_reg_loss)
        self.semantic_cls_loss = build_loss(semantic_cls_loss)

        # The aggregation MLP consumes the vote features directly, so its
        # input width must equal the vote module's input width.
        assert vote_aggregation_cfg['mlp_channels'][0] == vote_moudule_cfg[
            'in_channels']

        # Primitive existence flag prediction (2 logits per seed point).
        self.flag_conv = ConvModule(
            vote_moudule_cfg['conv_channels'][-1],
            vote_moudule_cfg['conv_channels'][-1] // 2,
            1,
            padding=0,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            bias=True,
            inplace=True)
        self.flag_pred = torch.nn.Conv1d(
            vote_moudule_cfg['conv_channels'][-1] // 2, 2, 1)

        self.vote_module = VoteModule(**vote_moudule_cfg)
        self.vote_aggregation = build_sa_module(vote_aggregation_cfg)

        prev_channel = vote_aggregation_cfg['mlp_channels'][-1]
        conv_pred_list = list()
        for k in range(len(feat_channels)):
            conv_pred_list.append(
                ConvModule(
                    prev_channel,
                    feat_channels[k],
                    1,
                    padding=0,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    bias=True,
                    inplace=True))
            prev_channel = feat_channels[k]
        self.conv_pred = nn.Sequential(*conv_pred_list)

        # Output layout, as sliced in ``primitive_decode_scores``:
        # 3 center offsets + num_dims semantic values + num_classes scores.
        conv_out_channel = 3 + num_dims + num_classes
        self.conv_pred.add_module('conv_out',
                                  nn.Conv1d(prev_channel, conv_out_channel, 1))

    def init_weights(self):
        """Initialize weights of VoteHead."""
        pass

    def forward(self, feats_dict, sample_mod):
        """Run the primitive head on backbone features.

        Args:
            feats_dict (dict): Feature dict from backbone; this method reads
                ``'fp_xyz_net0'`` (last entry) and ``'hd_feature'``.
            sample_mod (str): Sample mode for the vote aggregation layer,
                one of "vote", "seed" and "random".

        Returns:
            dict: Predictions of the primitive head. Every key is suffixed
                with ``self.primitive_mode``.
        """
        assert sample_mod in ['vote', 'seed', 'random']

        mode = self.primitive_mode
        seed_xyz = feats_dict['fp_xyz_net0'][-1]
        seed_feats = feats_dict['hd_feature']

        # Per-seed primitive existence logits.
        flag_logits = self.flag_pred(self.flag_conv(seed_feats))
        outputs = {f'pred_flag_{mode}': flag_logits}

        # Step 1: seeds cast votes towards primitive centers.
        vote_xyz, vote_feats = self.vote_module(seed_xyz, seed_feats)
        outputs[f'vote_{mode}'] = vote_xyz
        outputs[f'vote_features_{mode}'] = vote_feats

        # Step 2: choose which votes the aggregation layer groups around.
        if sample_mod == 'seed':
            # FPS over the seeds; the matching votes are then used.
            sample_indices = furthest_point_sample(seed_xyz,
                                                   self.num_proposal)
        elif sample_mod == 'random':
            batch_size, num_seed = seed_xyz.shape[:2]
            sample_indices = torch.randint(
                0,
                num_seed, (batch_size, self.num_proposal),
                dtype=torch.int32,
                device=seed_xyz.device)
        elif sample_mod == 'vote':
            # Let the aggregation layer run FPS on the votes itself.
            sample_indices = None
        else:
            raise NotImplementedError('Unsupported sample mod!')

        agg_xyz, agg_feats, agg_inds = self.vote_aggregation(
            vote_xyz, vote_feats, sample_indices)
        outputs[f'aggregated_points_{mode}'] = agg_xyz
        outputs[f'aggregated_features_{mode}'] = agg_feats
        outputs[f'aggregated_indices_{mode}'] = agg_inds

        # Steps 3-4: predict raw offsets/semantics and decode them.
        decoded = self.primitive_decode_scores(
            self.conv_pred(agg_feats), agg_xyz)
        outputs.update(decoded)

        center, pred_ind = self.get_primitive_center(
            flag_logits, decoded[f'center_{mode}'])
        outputs[f'pred_{mode}_ind'] = pred_ind
        outputs[f'pred_{mode}_center'] = center
        return outputs

    def loss(self,
             bbox_preds,
             points,
             gt_bboxes_3d,
             gt_labels_3d,
             pts_semantic_mask=None,
             pts_instance_mask=None,
             img_metas=None,
             gt_bboxes_ignore=None):
        """Compute the losses of the primitive head.

        Args:
            bbox_preds (dict): Predictions from forward of primitive head.
            points (list[torch.Tensor]): Input points.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth \
                bboxes of each sample.
            gt_labels_3d (list[torch.Tensor]): Labels of each sample.
            pts_semantic_mask (None | list[torch.Tensor]): Point-wise
                semantic mask.
            pts_instance_mask (None | list[torch.Tensor]): Point-wise
                instance mask.
            img_metas (list[dict]): Contain pcd and img's meta info.
                Not used here.
            gt_bboxes_ignore (None | list[torch.Tensor]): Specify which
                bounding boxes to ignore. Not used here.

        Returns:
            dict: Flag, vote, center, size and semantic losses, each key
                suffixed with ``self.primitive_mode``.
        """
        mode = self.primitive_mode
        (point_mask, point_offset, gt_primitive_center, gt_primitive_semantic,
         gt_sem_cls_label, gt_primitive_mask) = self.get_targets(
             points, gt_bboxes_3d, gt_labels_3d, pts_semantic_mask,
             pts_instance_mask, bbox_preds)

        losses = dict()
        # The existence flag uses the raw (un-normalized) mask as labels.
        losses[f'flag_loss_{mode}'] = self.objectness_loss(
            bbox_preds[f'pred_flag_{mode}'], gt_primitive_mask.long())

        losses[f'vote_loss_{mode}'] = self.vote_module.get_loss(
            bbox_preds['seed_points'], bbox_preds[f'vote_{mode}'],
            bbox_preds['seed_indices'], point_mask, point_offset)

        num_proposal = bbox_preds[f'aggregated_points_{mode}'].shape[1]
        primitive_center = bbox_preds[f'center_{mode}']
        # Line primitives carry no size/semantic regression output.
        primitive_semantic = None if mode == 'line' else \
            bbox_preds[f'size_residuals_{mode}'].contiguous()
        semantic_scores = bbox_preds[f'sem_cls_scores_{mode}'].transpose(2, 1)

        # Turn the mask into per-element weights summing to ~1.
        weights = gt_primitive_mask / (gt_primitive_mask.sum() + 1e-6)
        center_loss, size_loss, sem_cls_loss = self.compute_primitive_loss(
            primitive_center, primitive_semantic, semantic_scores,
            num_proposal, gt_primitive_center, gt_primitive_semantic,
            gt_sem_cls_label, weights)
        losses[f'center_loss_{mode}'] = center_loss
        losses[f'size_loss_{mode}'] = size_loss
        losses[f'sem_loss_{mode}'] = sem_cls_loss

        return losses

    def get_targets(self,
                    points,
                    gt_bboxes_3d,
                    gt_labels_3d,
                    pts_semantic_mask=None,
                    pts_instance_mask=None,
                    bbox_preds=None):
        """Build primitive-head training targets for a whole batch.

        Samples without any ground truth receive one all-zero placeholder
        box with label 0. The entries of ``gt_bboxes_3d`` and
        ``gt_labels_3d`` are replaced in place, so the caller's lists
        change.

        Args:
            points (list[torch.Tensor]): Points of each batch.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth \
                bboxes of each batch.
            gt_labels_3d (list[torch.Tensor]): Labels of each batch.
            pts_semantic_mask (None | list[torch.Tensor]): Point-wise semantic
                label of each batch.
            pts_instance_mask (None | list[torch.Tensor]): Point-wise instance
                label of each batch.
            bbox_preds (dict): Predictions from forward of primitive head.

        Returns:
            tuple[torch.Tensor]: ``(point_mask, point_offset,
                gt_primitive_center, gt_primitive_semantic,
                gt_sem_cls_label, gt_votes_mask)``; the last four are
                gathered at ``bbox_preds['seed_indices']``.
        """
        mode = self.primitive_mode

        # Pad empty samples with a single dummy box/label.
        for idx, labels in enumerate(gt_labels_3d):
            if len(labels) > 0:
                continue
            box_tensor = gt_bboxes_3d[idx].tensor
            placeholder = box_tensor.new_zeros(1, box_tensor.shape[-1])
            gt_bboxes_3d[idx] = gt_bboxes_3d[idx].new_box(placeholder)
            gt_labels_3d[idx] = labels.new_zeros(1)

        if pts_semantic_mask is None:
            pts_semantic_mask = [None] * len(gt_labels_3d)
            pts_instance_mask = [None] * len(gt_labels_3d)

        per_sample = multi_apply(self.get_targets_single, points,
                                 gt_bboxes_3d, gt_labels_3d,
                                 pts_semantic_mask, pts_instance_mask)
        point_mask, point_sem, point_offset = [
            torch.stack(items) for items in per_sample
        ]

        batch_size = point_mask.shape[0]
        num_proposal = bbox_preds['aggregated_points_' + mode].shape[1]
        seed_xyz = bbox_preds['seed_points']
        num_seed = seed_xyz.shape[1]
        seed_inds = bbox_preds['seed_indices'].long()
        column_inds = seed_inds.view(batch_size, num_seed, 1)

        # Center target = seed position + gt offset at that seed.
        seed_gt_votes = torch.gather(point_offset, 1,
                                     column_inds.repeat(1, 1, 3))
        seed_gt_votes += seed_xyz
        gt_primitive_center = seed_gt_votes.view(batch_size * num_proposal,
                                                 1, 3)

        sem_width = 4 + self.num_dims
        seed_gt_sem = torch.gather(point_sem, 1,
                                   column_inds.repeat(1, 1, sem_width))
        gt_primitive_semantic = seed_gt_sem[:, :, 3:3 + self.num_dims].view(
            batch_size * num_proposal, 1, self.num_dims).contiguous()
        gt_sem_cls_label = seed_gt_sem[:, :, -1].long()

        gt_votes_mask = torch.gather(point_mask, 1, seed_inds)

        return (point_mask, point_offset, gt_primitive_center,
                gt_primitive_semantic, gt_sem_cls_label, gt_votes_mask)

    def get_targets_single(self,
                           points,
                           gt_bboxes_3d,
                           gt_labels_3d,
                           pts_semantic_mask=None,
                           pts_instance_mask=None):
        """Generate targets of primitive head for single batch.

        For every ground-truth instance, its points are matched against the
        planes of the corresponding box: the lower/upper planes and the two
        pairs of side planes. Depending on ``self.primitive_mode``, matched
        points get line targets ('line'), surface targets on the
        lower/upper planes ('z') or surface targets on the side planes
        ('xy').

        Note:
            The ``i``-th unique instance label is paired with
            ``gt_bboxes_3d.corners[i]``. This presumably relies on the data
            pipeline keeping instance ids and box order aligned — TODO
            confirm. ``pts_semantic_mask`` and ``pts_instance_mask`` are
            used as tensors, so passing ``None`` fails here.

        Args:
            points (torch.Tensor): Points of each batch.
            gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): Ground truth \
                boxes of each batch.
            gt_labels_3d (torch.Tensor): Labels of each batch.
            pts_semantic_mask (None | torch.Tensor): Point-wise semantic
                label of each batch.
            pts_instance_mask (None | torch.Tensor): Point-wise instance
                label of each batch.

        Returns:
            tuple[torch.Tensor]: ``(point_mask, point_sem, point_offset)``
                of shapes ``(N,)``, ``(N, 3 + num_dims + 1)`` and ``(N, 3)``.

        Raises:
            NotImplementedError: If the upper corners of a box are not
                horizontal, the fitted upper plane is too far from them, or
                a side-plane normal is not horizontal.
        """
        gt_bboxes_3d = gt_bboxes_3d.to(points.device)
        num_points = points.shape[0]

        point_mask = points.new_zeros(num_points)
        # Offset to the primitive center
        point_offset = points.new_zeros([num_points, 3])
        # Semantic information of primitive center. get_targets reads
        # columns 3:3 + num_dims as semantics and the last one as the class
        # label; the columns are filled by the _assign_primitive_* helpers.
        point_sem = points.new_zeros([num_points, 3 + self.num_dims + 1])

        # Points labelled ``num_classes`` are excluded (presumably the
        # unannotated / background class).
        instance_flag = torch.nonzero(
            pts_semantic_mask != self.num_classes).squeeze(1)
        instance_labels = pts_instance_mask[instance_flag].unique()

        for i, i_instance in enumerate(instance_labels):
            indices = instance_flag[pts_instance_mask[instance_flag] ==
                                    i_instance]
            coords = points[indices, :3]
            cur_cls_label = pts_semantic_mask[indices][0]

            # Bbox Corners
            cur_corners = gt_bboxes_3d.corners[i]
            xmin, ymin, _ = cur_corners.min(0)[0]
            xmax, ymax, _ = cur_corners.max(0)[0]

            # Lower plane passes through corner 7; the upper plane is fitted
            # to corners 1, 2, 5 and 6.
            plane_lower_temp = points.new_tensor(
                [0, 0, 1, -cur_corners[7, -1]])
            upper_points = cur_corners[[1, 2, 5, 6]]
            refined_distance = (upper_points * plane_lower_temp[:3]).sum(dim=1)

            if self.check_horizon(upper_points) and \
                    plane_lower_temp[0] + plane_lower_temp[1] < \
                    self.train_cfg['lower_thresh']:
                plane_lower = points.new_tensor(
                    [0, 0, 1, plane_lower_temp[-1]])
                plane_upper = points.new_tensor(
                    [0, 0, 1, -torch.mean(refined_distance)])
            else:
                raise NotImplementedError('Only horizontal plane is support!')

            # check_dist yields a tensor, so an identity test against False
            # could never succeed; test its truth value instead.
            if not self.check_dist(plane_upper, upper_points):
                raise NotImplementedError(
                    'Mean distance to plane should be lower than thresh!')

            # Get the boundary points here
            point2plane_dist, selected = self.match_point2plane(
                plane_lower, coords)

            # Get lower four lines
            if self.primitive_mode == 'line':
                point2line_matching = self.match_point2line(
                    coords[selected], xmin, xmax, ymin, ymax)

                point_mask, point_offset, point_sem = \
                    self._assign_primitive_line_targets(point_mask,
                                                        point_offset,
                                                        point_sem,
                                                        coords[selected],
                                                        indices[selected],
                                                        cur_cls_label,
                                                        point2line_matching,
                                                        cur_corners,
                                                        [1, 1, 0, 0])

            # Set the surface labels here
            if self.primitive_mode == 'z' and \
                    selected.sum() > self.train_cfg['num_point'] and \
                    point2plane_dist[selected].var() < \
                    self.train_cfg['var_thresh']:

                point_mask, point_offset, point_sem = \
                    self._assign_primitive_surface_targets(point_mask,
                                                           point_offset,
                                                           point_sem,
                                                           coords[selected],
                                                           indices[selected],
                                                           cur_cls_label,
                                                           cur_corners)

            # Get the boundary points here
            point2plane_dist, selected = self.match_point2plane(
                plane_upper, coords)

            # Get upper four lines
            if self.primitive_mode == 'line':
                point2line_matching = self.match_point2line(
                    coords[selected], xmin, xmax, ymin, ymax)

                point_mask, point_offset, point_sem = \
                    self._assign_primitive_line_targets(point_mask,
                                                        point_offset,
                                                        point_sem,
                                                        coords[selected],
                                                        indices[selected],
                                                        cur_cls_label,
                                                        point2line_matching,
                                                        cur_corners,
                                                        [1, 1, 0, 0])

            if self.primitive_mode == 'z' and \
                    selected.sum() > self.train_cfg['num_point'] and \
                    point2plane_dist[selected].var() < \
                    self.train_cfg['var_thresh']:

                point_mask, point_offset, point_sem = \
                    self._assign_primitive_surface_targets(point_mask,
                                                           point_offset,
                                                           point_sem,
                                                           coords[selected],
                                                           indices[selected],
                                                           cur_cls_label,
                                                           cur_corners)

            # Left/right side planes (their normal must be horizontal)
            plane_left_temp = self._get_plane_fomulation(
                cur_corners[2] - cur_corners[3],
                cur_corners[3] - cur_corners[0], cur_corners[0])

            right_points = cur_corners[[4, 5, 7, 6]]
            plane_left_temp /= torch.norm(plane_left_temp[:3])
            refined_distance = (right_points * plane_left_temp[:3]).sum(dim=1)

            if plane_left_temp[2] < self.train_cfg['lower_thresh']:
                plane_left = plane_left_temp
                plane_right = points.new_tensor([
                    plane_left_temp[0], plane_left_temp[1], plane_left_temp[2],
                    -refined_distance.mean()
                ])
            else:
                raise NotImplementedError(
                    'Normal vector of the plane should be horizontal!')

            # Get the boundary points here
            point2plane_dist, selected = self.match_point2plane(
                plane_left, coords)

            # Get left two lines
            if self.primitive_mode == 'line':
                _, _, line_sel1, line_sel2 = self.match_point2line(
                    coords[selected], xmin, xmax, ymin, ymax)
                point_mask, point_offset, point_sem = \
                    self._assign_primitive_line_targets(point_mask,
                                                        point_offset,
                                                        point_sem,
                                                        coords[selected],
                                                        indices[selected],
                                                        cur_cls_label,
                                                        [line_sel1, line_sel2],
                                                        cur_corners,
                                                        [2, 2])

            if self.primitive_mode == 'xy' and \
                    selected.sum() > self.train_cfg['num_point'] and \
                    point2plane_dist[selected].var() < \
                    self.train_cfg['var_thresh']:

                point_mask, point_offset, point_sem = \
                    self._assign_primitive_surface_targets(point_mask,
                                                           point_offset,
                                                           point_sem,
                                                           coords[selected],
                                                           indices[selected],
                                                           cur_cls_label,
                                                           cur_corners)

            # Get the boundary points here
            point2plane_dist, selected = self.match_point2plane(
                plane_right, coords)

            # Get right two lines
            if self.primitive_mode == 'line':
                _, _, line_sel1, line_sel2 = self.match_point2line(
                    coords[selected], xmin, xmax, ymin, ymax)

                point_mask, point_offset, point_sem = \
                    self._assign_primitive_line_targets(point_mask,
                                                        point_offset,
                                                        point_sem,
                                                        coords[selected],
                                                        indices[selected],
                                                        cur_cls_label,
                                                        [line_sel1, line_sel2],
                                                        cur_corners,
                                                        [2, 2])

            if self.primitive_mode == 'xy' and \
                    selected.sum() > self.train_cfg['num_point'] and \
                    point2plane_dist[selected].var() < \
                    self.train_cfg['var_thresh']:

                point_mask, point_offset, point_sem = \
                    self._assign_primitive_surface_targets(point_mask,
                                                           point_offset,
                                                           point_sem,
                                                           coords[selected],
                                                           indices[selected],
                                                           cur_cls_label,
                                                           cur_corners)

            # Front/back side planes (their normal must be horizontal)
            plane_front_temp = self._get_plane_fomulation(
                cur_corners[0] - cur_corners[4],
                cur_corners[4] - cur_corners[5], cur_corners[5])

            back_points = cur_corners[[3, 2, 7, 6]]
            plane_front_temp /= torch.norm(plane_front_temp[:3])
            refined_distance = (back_points * plane_front_temp[:3]).sum(dim=1)

            if plane_front_temp[2] < self.train_cfg['lower_thresh']:
                plane_front = plane_front_temp
                plane_back = points.new_tensor([
                    plane_front_temp[0], plane_front_temp[1],
                    plane_front_temp[2], -torch.mean(refined_distance)
                ])
            else:
                raise NotImplementedError(
                    'Normal vector of the plane should be horizontal!')

            # Get the boundary points here
            point2plane_dist, selected = self.match_point2plane(
                plane_front, coords)

            if self.primitive_mode == 'xy' and \
                    selected.sum() > self.train_cfg['num_point'] and \
                    (point2plane_dist[selected]).var() < \
                    self.train_cfg['var_thresh']:

                point_mask, point_offset, point_sem = \
                    self._assign_primitive_surface_targets(point_mask,
                                                           point_offset,
                                                           point_sem,
                                                           coords[selected],
                                                           indices[selected],
                                                           cur_cls_label,
                                                           cur_corners)

            # Get the boundary points here
            point2plane_dist, selected = self.match_point2plane(
                plane_back, coords)

            if self.primitive_mode == 'xy' and \
                    selected.sum() > self.train_cfg['num_point'] and \
                    point2plane_dist[selected].var() < \
                    self.train_cfg['var_thresh']:

                point_mask, point_offset, point_sem = \
                    self._assign_primitive_surface_targets(point_mask,
                                                           point_offset,
                                                           point_sem,
                                                           coords[selected],
                                                           indices[selected],
                                                           cur_cls_label,
                                                           cur_corners)

        return (point_mask, point_sem, point_offset)

    def primitive_decode_scores(self, predictions, aggregated_points):
        """Decode predicted parts to primitive head.

        Args:
            predictions (torch.Tensor): primitive pridictions of each batch.
            aggregated_points (torch.Tensor): The aggregated points
                of vote stage.

        Returns:
            Dict: Predictions of primitive head, including center,
                semantic size and semantic scores.
        """

        ret_dict = {}
        pred_transposed = predictions.transpose(2, 1)

        center = aggregated_points + pred_transposed[:, :, 0:3]
        ret_dict['center_' + self.primitive_mode] = center

        if self.primitive_mode in ['z', 'xy']:
            ret_dict['size_residuals_' + self.primitive_mode] = \
                pred_transposed[:, :, 3:3 + self.num_dims]

        ret_dict['sem_cls_scores_' + self.primitive_mode] = \
            pred_transposed[:, :, 3 + self.num_dims:]

        return ret_dict

    def check_horizon(self, points):
        """Tell whether the first four points share exactly the same height.

        Consecutive points ``0..3`` are compared on their last coordinate
        with exact equality, stopping at the first mismatch.

        Args:
            points (torch.Tensor): Points of input (at least four rows).

        Returns:
            torch.Tensor | bool: Result of the last comparison evaluated;
                truthy only when all four heights are equal.
        """
        result = None
        for row in range(3):
            result = points[row][-1] == points[row + 1][-1]
            if not result:
                break
        return result

    def check_dist(self, plane_equ, points):
        """Check whether points lie close enough to a horizontal plane.

        Only the z coordinate of each point and the constant term of
        ``plane_equ`` are read, so the plane is effectively treated as
        ``z + plane_equ[-1] = 0``. The summed signed offsets are divided by
        a fixed 4 (presumably the four corners of a box face), not by the
        actual number of points.

        Args:
            plane_equ (torch.Tensor): Plane to be checked; only its last
                element is used.
            points (torch.Tensor): Points to be checked; ``points[:, 2]`` is
                read as their height.

        Returns:
            torch.Tensor: Boolean flag, True when the averaged offset is
                below ``train_cfg['lower_thresh']``.
        """
        signed_offsets = points[:, 2] + plane_equ[-1]
        averaged = signed_offsets.sum() / 4.0
        return averaged < self.train_cfg['lower_thresh']

    def match_point2line(self, points, xmin, xmax, ymin, ymax):
        """Flag points lying near each of the four axis-aligned box edges.

        Args:
            points (torch.Tensor): Points of input; columns 0 and 1 are read
                as x and y.
            xmin (float): Min of X-axis.
            xmax (float): Max of X-axis.
            ymin (float): Min of Y-axis.
            ymax (float): Max of Y-axis.

        Returns:
            tuple[torch.Tensor]: Four boolean masks, in the order
                ``x ~ xmin``, ``x ~ xmax``, ``y ~ ymin``, ``y ~ ymax``, where
                ``~`` means an absolute difference below
                ``train_cfg['line_thresh']``.
        """
        thresh = self.train_cfg['line_thresh']
        x_coords, y_coords = points[:, 0], points[:, 1]
        pairs = ((x_coords, xmin), (x_coords, xmax),
                 (y_coords, ymin), (y_coords, ymax))
        return tuple(
            torch.abs(values - bound) < thresh for values, bound in pairs)

    def match_point2plane(self, plane, points):
        """Compute point-to-plane values and select the points nearest it.

        Args:
            plane (torch.Tensor): Plane equation ``[a, b, c, d]``; the first
                three entries are dotted with each point and ``d`` is added.
            points (torch.Tensor): Points of input, one point per row.

        Returns:
            tuple[torch.Tensor]: ``|a*x + b*y + c*z + d|`` for every point
                (a true distance only when ``(a, b, c)`` is unit length), and
                a boolean mask of the points whose value is within
                ``train_cfg['dist_thresh']`` of the smallest value.
        """
        normal, offset = plane[:3], plane[-1]
        dist = ((points * normal).sum(dim=1) + offset).abs()
        gap_to_closest = (dist - dist.min()).abs()
        return dist, gap_to_closest < self.train_cfg['dist_thresh']

    def compute_primitive_loss(self, primitive_center, primitive_semantic,
                               semantic_scores, num_proposal,
                               gt_primitive_center, gt_primitive_semantic,
                               gt_sem_cls_label, gt_primitive_mask):
        """Compute loss of primitive module.

        Args:
            primitive_center (torch.Tensor): Predictions of primitive center,
                reshaped to ``(batch_size * num_proposal, -1, 3)``.
            primitive_semantic (torch.Tensor | None): Predictions of primitive
                semantic (size residuals), reshaped to
                ``(batch_size * num_proposal, -1, num_dims)``. Not read in
                ``'line'`` mode.
            semantic_scores (torch.Tensor): Predictions of primitive
                semantic scores.
            num_proposal (int): The number of primitive proposal.
            gt_primitive_center (torch.Tensor): Ground truth of
                primitive center.
            gt_primitive_semantic (torch.Tensor): Ground truth of primitive
                semantic.
            gt_sem_cls_label (torch.Tensor): Ground truth of primitive
                semantic class.
            gt_primitive_mask (torch.Tensor): Ground truth of primitive mask,
                used as the per-proposal weight of every loss term.

        Returns:
            tuple[torch.Tensor]: Center loss, size loss (a zero tensor in
                ``'line'`` mode) and semantic classification loss.
        """
        batch_size = primitive_center.shape[0]
        num_rows = batch_size * num_proposal
        # One weight per proposal, shared by the center and size terms.
        proposal_weight = gt_primitive_mask.reshape(num_rows, 1)

        # ``reshape`` instead of ``view``: predictions can be slices of a
        # transposed tensor (see ``primitive_decode_scores``) and therefore
        # non-contiguous, which ``view`` rejects.
        center_reshape = primitive_center.reshape(num_rows, -1, 3)
        # Only element 1 of the tuple returned by the center loss is used.
        center_loss = self.center_loss(
            center_reshape,
            gt_primitive_center,
            dst_weight=proposal_weight)[1]

        if self.primitive_mode != 'line':
            size_reshape = primitive_semantic.reshape(num_rows, -1,
                                                      self.num_dims)
            size_loss = self.semantic_reg_loss(
                size_reshape,
                gt_primitive_semantic,
                dst_weight=proposal_weight)[1]
        else:
            # Line primitives have no size residuals to regress.
            size_loss = center_loss.new_tensor(0.0)

        # Semantic cls loss
        sem_cls_loss = self.semantic_cls_loss(
            semantic_scores, gt_sem_cls_label, weight=gt_primitive_mask)

        return center_loss, size_loss, sem_cls_loss

    def get_primitive_center(self, pred_flag, center):
        """Generate primitive centers from predictions.

        ``pred_flag`` is softmax-normalised over dim 1 and the probability of
        index 1 is compared with ``self.surface_thresh``. Proposals at or
        below the threshold have ``self.upper_thresh`` added to every
        coordinate of their center.

        Args:
            pred_flag (torch.Tensor): Scores of primitive center, with class
                scores along dim 1.
            center (torch.Tensor): Predictions of primitive center, with
                coordinates along the last dim.

        Returns:
            tuple[torch.Tensor]: The (possibly shifted) centers and a float
                0/1 tensor marking proposals whose probability exceeds
                ``self.surface_thresh``.
        """
        flag_prob = F.softmax(pred_flag, dim=1)[:, 1, :]
        pred_indices = (flag_prob > self.surface_thresh).detach().float()
        # ``<=`` (not the negation of ``>``) keeps NaN probabilities unshifted.
        rejected = (flag_prob <= self.surface_thresh).detach().float()
        # NOTE(review): presumably pushes rejected centers away so they are
        # not matched downstream — confirm with the consumer.
        shift = torch.ones_like(center) * self.upper_thresh
        shifted_center = center + shift * rejected.unsqueeze(-1)
        return shifted_center, pred_indices

    def _assign_primitive_line_targets(self, point_mask, point_offset,
                                       point_sem, coords, indices, cls_label,
                                       point2line_matching, corners,
                                       center_axises):
        """Generate targets of line primitive.

        Args:
            point_mask (torch.Tensor): Tensor to store the ground
                truth of mask.
            point_offset (torch.Tensor): Tensor to store the ground
                truth of offset.
            point_sem (torch.Tensor): Tensor to store the ground
                truth of semantic.
            coords (torch.Tensor): The selected points.
            indices (torch.Tensor): Indices of the selected points.
            cls_label (int): Class label of the ground truth bounding box.
            point2line_matching (torch.Tensor): Flag indicate that
                matching line of each point.
            corners (torch.Tensor): Corners of the ground truth bounding box.
            center_axises (list[int]): Indicate in which axis the line center
                should be refined.

        Returns:
            Tuple: Targets of the line primitive.
        """
        for line_select, center_axis in zip(point2line_matching,
                                            center_axises):
            if line_select.sum() > self.train_cfg['num_point_line']:
                point_mask[indices[line_select]] = 1.0
                line_center = coords[line_select].mean(dim=0)
                line_center[center_axis] = corners[:, center_axis].mean()
                point_offset[indices[line_select]] = \
                    line_center - coords[line_select]
                point_sem[indices[line_select]] = \
                    point_sem.new_tensor([line_center[0], line_center[1],
                                          line_center[2], cls_label])
        return point_mask, point_offset, point_sem

    def _assign_primitive_surface_targets(self, point_mask, point_offset,
                                          point_sem, coords, indices,
                                          cls_label, corners):
        """Generate targets for primitive z and primitive xy.

        Args:
            point_mask (torch.Tensor): Tensor to store the ground
                truth of mask.
            point_offset (torch.Tensor): Tensor to store the ground
                truth of offset.
            point_sem (torch.Tensor): Tensor to store the ground
                truth of semantic.
            coords (torch.Tensor): The selected points.
            indices (torch.Tensor): Indices of the selected points.
            cls_label (int): Class label of the ground truth bounding box.
            corners (torch.Tensor): Corners of the ground truth bounding box.

        Returns:
            Tuple: Targets of the center primitive.
        """
        point_mask[indices] = 1.0
        if self.primitive_mode == 'z':
            center = point_mask.new_tensor([
                corners[:, 0].mean(), corners[:, 1].mean(), coords[:,
                                                                   2].mean()
            ])
            point_sem[indices] = point_sem.new_tensor([
                center[0], center[1], center[2],
                corners[:, 0].max() - corners[:, 0].min(),
                corners[:, 1].max() - corners[:, 1].min(), cls_label
            ])
        elif self.primitive_mode == 'xy':
            center = point_mask.new_tensor([
                coords[:, 0].mean(), coords[:, 1].mean(), corners[:, 2].mean()
            ])
            point_sem[indices] = point_sem.new_tensor([
                center[0], center[1], center[2],
                corners[:, 2].max() - corners[:, 2].min(), cls_label
            ])
        point_offset[indices] = center - coords
        return point_mask, point_offset, point_sem

    def _get_plane_fomulation(self, vector1, vector2, point):
        """Compute the equation of the plane.

        Args:
            vector1 (torch.Tensor): Parallel vector of the plane.
            vector2 (torch.Tensor): Parallel vector of the plane.
            point (torch.Tensor): Point on the plane.

        Returns:
            torch.Tensor: Equation of the plane.
        """
        surface_norm = torch.cross(vector1, vector2)
        surface_dis = -torch.dot(surface_norm, point)
        plane = point.new_tensor(
            [surface_norm[0], surface_norm[1], surface_norm[2], surface_dis])
        return plane