# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.ops import Voxelization
from mmcv.runner import force_fp32
from torch.nn import functional as F

from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d
from .. import builder
from ..builder import DETECTORS
from .single_stage import SingleStage3DDetector
12


@DETECTORS.register_module()
class VoxelNet(SingleStage3DDetector):
    r"""`VoxelNet <https://arxiv.org/abs/1711.06396>`_ for 3D detection."""

    def __init__(self,
                 voxel_layer,
                 voxel_encoder,
                 middle_encoder,
                 backbone,
                 neck=None,
                 bbox_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=None,
                 pretrained=None):
        """Initialize the detector from config dicts.

        Args:
            voxel_layer (dict): Config of the hard voxelization layer.
            voxel_encoder (dict): Config of the voxel feature encoder.
            middle_encoder (dict): Config of the middle encoder.
            backbone (dict): Config of the backbone network.
            neck (dict, optional): Config of the neck. Defaults to None.
            bbox_head (dict, optional): Config of the box head.
                Defaults to None.
            train_cfg (dict, optional): Training config. Defaults to None.
            test_cfg (dict, optional): Testing config. Defaults to None.
            init_cfg (dict, optional): Weight initialization config.
                Defaults to None.
            pretrained (str, optional): Path of pretrained weights.
                Defaults to None.
        """
        # Backbone/neck/head construction is delegated to the base class.
        super(VoxelNet, self).__init__(
            backbone=backbone,
            neck=neck,
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            init_cfg=init_cfg,
            pretrained=pretrained)
        # Groups raw points into fixed-capacity voxels (hard voxelization).
        self.voxel_layer = Voxelization(**voxel_layer)
        self.voxel_encoder = builder.build_voxel_encoder(voxel_encoder)
        self.middle_encoder = builder.build_middle_encoder(middle_encoder)

40
    def extract_feat(self, points, img_metas=None):
        """Extract feature maps from raw point clouds.

        Args:
            points (list[torch.Tensor]): Point cloud of each sample.
            img_metas (list[dict], optional): Meta information of each
                sample. Defaults to None.

        Returns:
            Features produced by the backbone (and neck, when present).
        """
        voxel_batch, points_per_voxel, voxel_coors = self.voxelize(points)
        encoded = self.voxel_encoder(voxel_batch, points_per_voxel,
                                     voxel_coors)
        # Column 0 of the coordinates holds the sample index; the last row
        # carries the largest one, which determines the batch size.
        num_samples = voxel_coors[-1, 0].item() + 1
        feats = self.middle_encoder(encoded, voxel_coors, num_samples)
        feats = self.backbone(feats)
        return self.neck(feats) if self.with_neck else feats

    @torch.no_grad()
    @force_fp32()
    def voxelize(self, points):
        """Apply hard voxelization to a batch of point clouds.

        Args:
            points (list[torch.Tensor]): Point cloud of each sample.

        Returns:
            tuple: Concatenated per-voxel points, per-voxel point counts,
                and voxel coordinates with the sample index prepended in
                column 0.
        """
        voxel_list, coor_list, count_list = [], [], []
        for pts in points:
            pts_voxels, pts_coors, pts_counts = self.voxel_layer(pts)
            voxel_list.append(pts_voxels)
            coor_list.append(pts_coors)
            count_list.append(pts_counts)
        voxels = torch.cat(voxel_list, dim=0)
        num_points = torch.cat(count_list, dim=0)
        # Prepend the sample index so voxels from different samples stay
        # distinguishable after concatenation.
        coors_batch = torch.cat([
            F.pad(coor, (1, 0), mode='constant', value=sample_id)
            for sample_id, coor in enumerate(coor_list)
        ], dim=0)
        return voxels, num_points, coors_batch

    def forward_train(self,
                      points,
                      img_metas,
                      gt_bboxes_3d,
                      gt_labels_3d,
                      gt_bboxes_ignore=None):
        """Training forward function.

        Args:
            points (list[torch.Tensor]): Point cloud of each sample.
            img_metas (list[dict]): Meta information of each sample.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
                boxes of each sample.
            gt_labels_3d (list[torch.Tensor]): Ground truth labels of the
                boxes of each sample.
            gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
                boxes to be ignored. Defaults to None.

        Returns:
            dict: Losses of each branch.
        """
        feats = self.extract_feat(points, img_metas)
        head_outs = self.bbox_head(feats)
        # Append the targets to the head predictions so the head can
        # compute its losses.
        return self.bbox_head.loss(
            *(head_outs + (gt_bboxes_3d, gt_labels_3d, img_metas)),
            gt_bboxes_ignore=gt_bboxes_ignore)

zhangwenwei's avatar
zhangwenwei committed
98
    def simple_test(self, points, img_metas, imgs=None, rescale=False):
        """Test function without augmentation.

        Args:
            points (list[torch.Tensor]): Point cloud of each sample.
            img_metas (list[dict]): Meta information of each sample.
            imgs (list[torch.Tensor], optional): Images. Unused here, kept
                for a unified interface. Defaults to None.
            rescale (bool): Whether to rescale the results.
                Defaults to False.

        Returns:
            list[dict]: Predicted 3D boxes, scores and labels per sample.
        """
        feats = self.extract_feat(points, img_metas)
        head_outs = self.bbox_head(feats)
        proposal_list = self.bbox_head.get_bboxes(
            *head_outs, img_metas, rescale=rescale)
        return [
            bbox3d2result(det_bboxes, det_scores, det_labels)
            for det_bboxes, det_scores, det_labels in proposal_list
        ]
zhangwenwei's avatar
zhangwenwei committed
109

zhangwenwei's avatar
zhangwenwei committed
110
    def aug_test(self, points, img_metas, imgs=None, rescale=False):
        """Test function with augmentation.

        Args:
            points (list[torch.Tensor]): Point clouds of the augmented
                views.
            img_metas (list[dict]): Meta information of each view.
            imgs (list[torch.Tensor], optional): Images. Unused here, kept
                for a unified interface. Defaults to None.
            rescale (bool): Whether to rescale the results.
                Defaults to False.

        Returns:
            list[dict]: Merged predictions for the single test sample.
        """
        feats = self.extract_feats(points, img_metas)

        # Only a single sample is supported: keep the first (and only)
        # detection result of every augmented view.
        aug_bboxes = []
        for view_feat, view_meta in zip(feats, img_metas):
            head_outs = self.bbox_head(view_feat)
            det_list = self.bbox_head.get_bboxes(
                *head_outs, view_meta, rescale=rescale)
            det_list = [
                dict(boxes_3d=det_bboxes, scores_3d=det_scores,
                     labels_3d=det_labels)
                for det_bboxes, det_scores, det_labels in det_list
            ]
            aug_bboxes.append(det_list[0])

        # Merging rescales the boxes back to the original scale.
        merged_bboxes = merge_aug_bboxes_3d(aug_bboxes, img_metas,
                                            self.bbox_head.test_cfg)
        return [merged_bboxes]