import torch
import torch.nn.functional as F

from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d
from mmdet3d.ops import Voxelization
from mmdet.models import DETECTORS
from .. import builder
from .single_stage import SingleStage3DDetector


@DETECTORS.register_module()
class VoxelNet(SingleStage3DDetector):
    """Single-stage voxel-based 3D detector.

    Raw point clouds are hard-voxelized, encoded per voxel, scattered
    into a dense feature map by the middle encoder, and then processed
    by a conventional 2D backbone / neck / bbox head.
    """

    def __init__(self,
                 voxel_layer,
                 voxel_encoder,
                 middle_encoder,
                 backbone,
                 neck=None,
                 bbox_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super(VoxelNet, self).__init__(
            backbone=backbone,
            neck=neck,
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained,
        )
        # Voxelization layer plus the two point-cloud-specific encoders.
        self.voxel_layer = Voxelization(**voxel_layer)
        self.voxel_encoder = builder.build_voxel_encoder(voxel_encoder)
        self.middle_encoder = builder.build_middle_encoder(middle_encoder)

    def extract_feat(self, points, img_metas):
        """Extract features from a batch of point clouds.

        Args:
            points (list[torch.Tensor]): Points of each sample.
            img_metas (list[dict]): Meta information of each sample
                (unused here; kept for the detector interface).

        Returns:
            Feature maps produced by the backbone (and neck, if any).
        """
        voxels, num_points, coors = self.voxelize(points)
        voxel_features = self.voxel_encoder(voxels, num_points, coors)
        # Column 0 of ``coors`` holds the sample index, so the largest
        # index in the last row determines the batch size.
        batch_size = coors[-1, 0].item() + 1
        feats = self.middle_encoder(voxel_features, coors, batch_size)
        feats = self.backbone(feats)
        return self.neck(feats) if self.with_neck else feats

    @torch.no_grad()
    def voxelize(self, points):
        """Apply hard voxelization to each sample and batch the results.

        Args:
            points (list[torch.Tensor]): Points of each sample.

        Returns:
            tuple: Concatenated voxels, per-voxel point counts, and
            coordinates with a leading batch-index column.
        """
        per_sample = [self.voxel_layer(res) for res in points]
        voxels = torch.cat([vox for vox, _, _ in per_sample], dim=0)
        num_points = torch.cat([cnt for _, _, cnt in per_sample], dim=0)
        # Prepend the sample index so coordinates identify their sample.
        coors_batch = torch.cat([
            F.pad(coor, (1, 0), mode='constant', value=i)
            for i, (_, coor, _) in enumerate(per_sample)
        ], dim=0)
        return voxels, num_points, coors_batch

    def forward_train(self,
                      points,
                      img_metas,
                      gt_bboxes_3d,
                      gt_labels_3d,
                      gt_bboxes_ignore=None):
        """Compute training losses from points and 3D ground truths.

        Args:
            points (list[torch.Tensor]): Points of each sample.
            img_metas (list[dict]): Meta information of each sample.
            gt_bboxes_3d (list): Ground-truth 3D boxes of each sample.
            gt_labels_3d (list): Ground-truth labels of each sample.
            gt_bboxes_ignore (list, optional): Boxes to be ignored.

        Returns:
            dict: Losses produced by the bbox head.
        """
        feats = self.extract_feat(points, img_metas)
        head_outs = self.bbox_head(feats)
        loss_inputs = head_outs + (gt_bboxes_3d, gt_labels_3d, img_metas)
        return self.bbox_head.loss(
            *loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)

    def simple_test(self, points, img_metas, imgs=None, rescale=False):
        """Test without augmentation; returns results for one sample."""
        feats = self.extract_feat(points, img_metas)
        head_outs = self.bbox_head(feats)
        bbox_list = self.bbox_head.get_bboxes(
            *head_outs, img_metas, rescale=rescale)
        bbox_results = [
            bbox3d2result(bboxes, scores, labels)
            for bboxes, scores, labels in bbox_list
        ]
        return bbox_results[0]

    def aug_test(self, points, img_metas, imgs=None, rescale=False):
        """Test with augmentation, merging predictions across views."""
        feats = self.extract_feats(points, img_metas)

        # Only a single sample per augmented view is supported here.
        aug_bboxes = []
        for feat, img_meta in zip(feats, img_metas):
            head_outs = self.bbox_head(feat)
            bbox_list = self.bbox_head.get_bboxes(
                *head_outs, img_meta, rescale=rescale)
            bbox_dicts = [
                dict(boxes_3d=bboxes, scores_3d=scores, labels_3d=labels)
                for bboxes, scores, labels in bbox_list
            ]
            aug_bboxes.append(bbox_dicts[0])

        # After merging, bboxes are rescaled back to the original size.
        merged_bboxes = merge_aug_bboxes_3d(aug_bboxes, img_metas,
                                            self.bbox_head.test_cfg)

        return merged_bboxes