import torch
from mmcv.runner import force_fp32
from torch.nn import functional as F

from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d
from mmdet3d.ops import Voxelization
from mmdet.models import DETECTORS
from .. import builder
from .single_stage import SingleStage3DDetector


@DETECTORS.register_module()
class VoxelNet(SingleStage3DDetector):
    r"""`VoxelNet <https://arxiv.org/abs/1711.06396>`_ for 3D detection."""

    def __init__(self,
                 voxel_layer,
                 voxel_encoder,
                 middle_encoder,
                 backbone,
                 neck=None,
                 bbox_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=None,
                 pretrained=None):
        # The single-stage base class wires up backbone/neck/head; this
        # subclass only adds the point-cloud-specific voxelization front end.
        super().__init__(
            backbone=backbone,
            neck=neck,
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            init_cfg=init_cfg,
            pretrained=pretrained)
        self.voxel_layer = Voxelization(**voxel_layer)
        self.voxel_encoder = builder.build_voxel_encoder(voxel_encoder)
        self.middle_encoder = builder.build_middle_encoder(middle_encoder)

    def extract_feat(self, points, img_metas):
        """Extract features from points.

        Args:
            points (list[torch.Tensor]): Point cloud of each sample.
            img_metas (list[dict]): Meta information of each sample.

        Returns:
            Feature maps produced by the backbone (and neck, if present).
        """
        voxels, num_points, coors = self.voxelize(points)
        voxel_features = self.voxel_encoder(voxels, num_points, coors)
        # Column 0 of ``coors`` holds the batch index; the final row carries
        # the largest one, which therefore determines the batch size.
        batch_size = coors[-1, 0].item() + 1
        feats = self.middle_encoder(voxel_features, coors, batch_size)
        feats = self.backbone(feats)
        return self.neck(feats) if self.with_neck else feats

    @torch.no_grad()
    @force_fp32()
    def voxelize(self, points):
        """Apply hard voxelization to points.

        Args:
            points (list[torch.Tensor]): Point cloud of each sample.

        Returns:
            tuple: Concatenated voxels, per-voxel point counts and voxel
                coordinates with the batch index prepended as column 0.
        """
        voxel_list, coor_list, count_list = [], [], []
        for pts in points:
            sample_voxels, sample_coors, sample_counts = self.voxel_layer(pts)
            voxel_list.append(sample_voxels)
            coor_list.append(sample_coors)
            count_list.append(sample_counts)
        # Prefix every coordinate row with its sample's batch index so the
        # per-sample voxels can be told apart after concatenation.
        padded_coors = [
            F.pad(sample_coors, (1, 0), mode='constant', value=batch_id)
            for batch_id, sample_coors in enumerate(coor_list)
        ]
        return (torch.cat(voxel_list, dim=0), torch.cat(count_list, dim=0),
                torch.cat(padded_coors, dim=0))

    def forward_train(self,
                      points,
                      img_metas,
                      gt_bboxes_3d,
                      gt_labels_3d,
                      gt_bboxes_ignore=None):
        """Training forward function.

        Args:
            points (list[torch.Tensor]): Point cloud of each sample.
            img_metas (list[dict]): Meta information of each sample.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
                boxes for each sample.
            gt_labels_3d (list[torch.Tensor]): Ground truth labels for
                boxes of each sample.
            gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
                boxes to be ignored. Defaults to None.

        Returns:
            dict: Losses of each branch.
        """
        feats = self.extract_feat(points, img_metas)
        head_outs = self.bbox_head(feats)
        return self.bbox_head.loss(
            *head_outs,
            gt_bboxes_3d,
            gt_labels_3d,
            img_metas,
            gt_bboxes_ignore=gt_bboxes_ignore)

    def simple_test(self, points, img_metas, imgs=None, rescale=False):
        """Test function without augmentation."""
        feats = self.extract_feat(points, img_metas)
        head_outs = self.bbox_head(feats)
        bbox_list = self.bbox_head.get_bboxes(
            *head_outs, img_metas, rescale=rescale)
        return [
            bbox3d2result(bboxes, scores, labels)
            for bboxes, scores, labels in bbox_list
        ]

    def aug_test(self, points, img_metas, imgs=None, rescale=False):
        """Test function with augmentation."""
        feats = self.extract_feats(points, img_metas)

        # only support aug_test for one sample
        aug_bboxes = []
        for feat, meta in zip(feats, img_metas):
            head_outs = self.bbox_head(feat)
            results = self.bbox_head.get_bboxes(
                *head_outs, meta, rescale=rescale)
            bboxes, scores, labels = results[0]
            aug_bboxes.append(
                dict(boxes_3d=bboxes, scores_3d=scores, labels_3d=labels))

        # after merging, bboxes will be rescaled to the original image size
        merged_bboxes = merge_aug_bboxes_3d(aug_bboxes, img_metas,
                                            self.bbox_head.test_cfg)

        return [merged_bboxes]