import torch
from torch.nn import functional as F

from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d
from mmdet3d.ops import Voxelization
from mmdet.models import DETECTORS
from .. import builder
from .single_stage import SingleStage3DDetector


@DETECTORS.register_module()
class VoxelNet(SingleStage3DDetector):
    r"""`VoxelNet <https://arxiv.org/abs/1711.06396>`_ for 3D detection."""

    def __init__(self,
                 voxel_layer,
                 voxel_encoder,
                 middle_encoder,
                 backbone,
                 neck=None,
                 bbox_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super(VoxelNet, self).__init__(
            backbone=backbone,
            neck=neck,
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained,
        )
        self.voxel_layer = Voxelization(**voxel_layer)
        self.voxel_encoder = builder.build_voxel_encoder(voxel_encoder)
        self.middle_encoder = builder.build_middle_encoder(middle_encoder)

    def extract_feat(self, points, img_metas):
        """Extract features from points."""
        voxels, num_points, coors = self.voxelize(points)
        voxel_features = self.voxel_encoder(voxels, num_points, coors)
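        # `coors` carries the sample index in column 0 (added in `voxelize`),
        # so the index in the last row plus one gives the batch size.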
        batch_size = coors[-1, 0].item() + 1
        x = self.middle_encoder(voxel_features, coors, batch_size)
        x = self.backbone(x)
        if self.with_neck:
            x = self.neck(x)
        return x

    @torch.no_grad()
    def voxelize(self, points):
        """Apply hard voxelization to points."""
        voxels, coors, num_points = [], [], []
        for res in points:
            res_voxels, res_coors, res_num_points = self.voxel_layer(res)
            voxels.append(res_voxels)
            coors.append(res_coors)
            num_points.append(res_num_points)
        voxels = torch.cat(voxels, dim=0)
        num_points = torch.cat(num_points, dim=0)
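        # Prepend each voxel's sample index as an extra leading column, so
        # voxels from different samples can be concatenated into a single
        # coordinate tensor whose first column is the batch index.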
        coors_batch = []
        for i, coor in enumerate(coors):
            coor_pad = F.pad(coor, (1, 0), mode='constant', value=i)
            coors_batch.append(coor_pad)
        coors_batch = torch.cat(coors_batch, dim=0)
        return voxels, num_points, coors_batch

    def forward_train(self,
                      points,
                      img_metas,
                      gt_bboxes_3d,
                      gt_labels_3d,
                      gt_bboxes_ignore=None):
        """Training forward function.

        Args:
            points (list[torch.Tensor]): Point cloud of each sample.
            img_metas (list[dict]): Meta information of each sample.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
                boxes for each sample.
            gt_labels_3d (list[torch.Tensor]): Ground truth labels for
                boxes of each sample.
            gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
                boxes to be ignored. Defaults to None.

        Returns:
            dict: Losses of each branch.
        """
        x = self.extract_feat(points, img_metas)
        outs = self.bbox_head(x)
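        # `outs` is the tuple of multi-level predictions from the head;
        # appending the ground truths and metas forms the positional
        # arguments expected by `self.bbox_head.loss`.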
        loss_inputs = outs + (gt_bboxes_3d, gt_labels_3d, img_metas)
        losses = self.bbox_head.loss(
            *loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
        return losses

    def simple_test(self, points, img_metas, imgs=None, rescale=False):
        """Test function without augmentaiton."""
        x = self.extract_feat(points, img_metas)
        outs = self.bbox_head(x)
        bbox_list = self.bbox_head.get_bboxes(
            *outs, img_metas, rescale=rescale)
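        # `bbox_list` holds one (bboxes, scores, labels) tuple per sample;
        # bbox3d2result packs each tuple into a dict with keys `boxes_3d`,
        # `scores_3d` and `labels_3d`.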
        bbox_results = [
            bbox3d2result(bboxes, scores, labels)
            for bboxes, scores, labels in bbox_list
        ]
        return bbox_results[0]

    def aug_test(self, points, img_metas, imgs=None, rescale=False):
        """Test function with augmentaiton."""
        feats = self.extract_feats(points, img_metas)
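        # `extract_feats` applies `extract_feat` to every augmented point
        # cloud, yielding one set of features per test-time augmentation.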

        # only support aug_test for one sample
        aug_bboxes = []
        for x, img_meta in zip(feats, img_metas):
            outs = self.bbox_head(x)
            bbox_list = self.bbox_head.get_bboxes(
                *outs, img_meta, rescale=rescale)
            bbox_list = [
                dict(boxes_3d=bboxes, scores_3d=scores, labels_3d=labels)
                for bboxes, scores, labels in bbox_list
            ]
            aug_bboxes.append(bbox_list[0])

        # After merging, the boxes are mapped back to the original
        # (pre-augmentation) scale.
        merged_bboxes = merge_aug_bboxes_3d(aug_bboxes, img_metas,
                                            self.bbox_head.test_cfg)

        return merged_bboxes