voxelnet.py 3.89 KB
Newer Older
zhangwenwei's avatar
zhangwenwei committed
1
import torch
zhangwenwei's avatar
zhangwenwei committed
2
from torch.nn import functional as F
zhangwenwei's avatar
zhangwenwei committed
3

zhangwenwei's avatar
zhangwenwei committed
4
from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d
zhangwenwei's avatar
zhangwenwei committed
5
from mmdet3d.ops import Voxelization
zhangwenwei's avatar
zhangwenwei committed
6
from mmdet.models import DETECTORS
zhangwenwei's avatar
zhangwenwei committed
7
from .. import builder
zhangwenwei's avatar
zhangwenwei committed
8
from .single_stage import SingleStage3DDetector
zhangwenwei's avatar
zhangwenwei committed
9
10


11
@DETECTORS.register_module()
class VoxelNet(SingleStage3DDetector):
    """Single-stage voxel-based 3D detector (e.g. SECOND-style).

    Pipeline: raw point clouds -> hard voxelization -> voxel feature
    encoder -> middle encoder (produces a dense feature map) -> 2D
    backbone (+ optional neck) -> dense 3D bbox head.
    """

    def __init__(self,
                 voxel_layer,
                 voxel_encoder,
                 middle_encoder,
                 backbone,
                 neck=None,
                 bbox_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        """Initialize the detector.

        Args:
            voxel_layer (dict): Config for the hard voxelization layer.
            voxel_encoder (dict): Config for the voxel feature encoder.
            middle_encoder (dict): Config for the middle encoder.
            backbone (dict): Config for the 2D backbone.
            neck (dict, optional): Config for the neck.
            bbox_head (dict, optional): Config for the bbox head.
            train_cfg (dict, optional): Training config.
            test_cfg (dict, optional): Testing config.
            pretrained (str, optional): Path to pretrained weights.
        """
        super(VoxelNet, self).__init__(
            backbone=backbone,
            neck=neck,
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained,
        )
        # Hard voxelization is non-differentiable; see `voxelize`, which is
        # wrapped in @torch.no_grad() accordingly.
        self.voxel_layer = Voxelization(**voxel_layer)
        self.voxel_encoder = builder.build_voxel_encoder(voxel_encoder)
        self.middle_encoder = builder.build_middle_encoder(middle_encoder)

    def extract_feat(self, points, img_metas):
        """Extract features from points.

        Args:
            points (list[torch.Tensor]): Point clouds, one tensor per sample.
            img_metas (list[dict]): Per-sample meta info. Unused here; kept
                for interface compatibility with the base detector.

        Returns:
            Feature map(s) from the backbone, or from the neck if one is
            configured.
        """
        voxels, num_points, coors = self.voxelize(points)
        voxel_features = self.voxel_encoder(voxels, num_points, coors)
        # `coors` column 0 is the batch index and rows are ordered by sample
        # (see `voxelize`), so the last row carries the largest batch index.
        batch_size = coors[-1, 0].item() + 1
        x = self.middle_encoder(voxel_features, coors, batch_size)
        x = self.backbone(x)
        if self.with_neck:
            x = self.neck(x)
        return x

    @torch.no_grad()
    def voxelize(self, points):
        """Apply hard voxelization to points.

        Args:
            points (list[torch.Tensor]): Point clouds, one tensor per sample.

        Returns:
            tuple: Batched voxelization results:
                - voxels (torch.Tensor): Concatenated voxels of all samples.
                - num_points (torch.Tensor): Number of points per voxel.
                - coors_batch (torch.Tensor): Voxel coordinates with the
                  sample's batch index prepended as the first column.
        """
        voxels, coors, num_points = [], [], []
        for res in points:
            res_voxels, res_coors, res_num_points = self.voxel_layer(res)
            voxels.append(res_voxels)
            coors.append(res_coors)
            num_points.append(res_num_points)
        voxels = torch.cat(voxels, dim=0)
        num_points = torch.cat(num_points, dim=0)
        coors_batch = []
        for i, coor in enumerate(coors):
            # Prepend the batch index as an extra leading coordinate column
            # so voxels of different samples stay distinguishable after cat.
            coor_pad = F.pad(coor, (1, 0), mode='constant', value=i)
            coors_batch.append(coor_pad)
        coors_batch = torch.cat(coors_batch, dim=0)
        return voxels, num_points, coors_batch

    def forward_train(self,
                      points,
                      img_metas,
                      gt_bboxes_3d,
                      gt_labels_3d,
                      gt_bboxes_ignore=None):
        """Training forward function.

        Args:
            points (list[torch.Tensor]): Point clouds, one tensor per sample.
            img_metas (list[dict]): Per-sample meta info.
            gt_bboxes_3d (list): Ground truth 3D boxes per sample.
            gt_labels_3d (list[torch.Tensor]): Ground truth labels per sample.
            gt_bboxes_ignore (list, optional): Boxes to ignore in training.

        Returns:
            dict: Losses produced by the bbox head.
        """
        x = self.extract_feat(points, img_metas)
        outs = self.bbox_head(x)
        loss_inputs = outs + (gt_bboxes_3d, gt_labels_3d, img_metas)
        losses = self.bbox_head.loss(
            *loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
        return losses

    def simple_test(self, points, img_metas, imgs=None, rescale=False):
        """Test function without test-time augmentation.

        Args:
            points (list[torch.Tensor]): Point clouds, one tensor per sample.
            img_metas (list[dict]): Per-sample meta info.
            imgs (list, optional): Images; unused, kept for API compatibility.
            rescale (bool): Whether to rescale boxes to the original scale.

        Returns:
            dict: Predicted 3D boxes, scores, and labels for the first (only)
            sample.
        """
        x = self.extract_feat(points, img_metas)
        outs = self.bbox_head(x)
        bbox_list = self.bbox_head.get_bboxes(
            *outs, img_metas, rescale=rescale)
        bbox_results = [
            bbox3d2result(bboxes, scores, labels)
            for bboxes, scores, labels in bbox_list
        ]
        return bbox_results[0]

    def aug_test(self, points, img_metas, imgs=None, rescale=False):
        """Test function with test-time augmentation.

        Args:
            points (list[list[torch.Tensor]]): Augmented point clouds.
            img_metas (list[list[dict]]): Meta info for each augmentation.
            imgs (list, optional): Images; unused, kept for API compatibility.
            rescale (bool): Whether to rescale boxes to the original scale.

        Returns:
            dict: Merged 3D detection results over all augmentations.
        """
        # `extract_feats` (plural) is the base-class helper that maps
        # `extract_feat` over each augmented view.
        feats = self.extract_feats(points, img_metas)

        # only support aug_test for one sample
        aug_bboxes = []
        for x, img_meta in zip(feats, img_metas):
            outs = self.bbox_head(x)
            bbox_list = self.bbox_head.get_bboxes(
                *outs, img_meta, rescale=rescale)
            bbox_list = [
                dict(boxes_3d=bboxes, scores_3d=scores, labels_3d=labels)
                for bboxes, scores, labels in bbox_list
            ]
            aug_bboxes.append(bbox_list[0])

        # after merging, bboxes will be rescaled to the original image size
        merged_bboxes = merge_aug_bboxes_3d(aug_bboxes, img_metas,
                                            self.bbox_head.test_cfg)

        return merged_bboxes