"official/projects/yolo/configs/backbones.py" did not exist on "9474c108cc4393b4fbdad86328f2b9d10a17d14a"
voxelnet.py 5.61 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional

import torch
from mmcv.ops import Voxelization
from mmcv.runner import force_fp32
from torch.nn import functional as F

from mmdet3d.core import Det3DDataSample
from mmdet3d.registry import MODELS
from .single_stage import SingleStage3DDetector


@MODELS.register_module()
class VoxelNet(SingleStage3DDetector):
    r"""`VoxelNet <https://arxiv.org/abs/1711.06396>`_ for 3D detection."""

    def __init__(self,
                 voxel_layer: dict,
                 voxel_encoder: dict,
                 middle_encoder: dict,
                 backbone: dict,
                 neck: Optional[dict] = None,
                 bbox_head: Optional[dict] = None,
                 train_cfg: Optional[dict] = None,
                 test_cfg: Optional[dict] = None,
                 init_cfg: Optional[dict] = None,
                 pretrained: Optional[str] = None) -> None:
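        # A minimal configuration sketch (not taken from this file; the
        # registered component names and all values below follow a typical
        # SECOND-style setup and should be treated as illustrative
        # assumptions):
        #   voxel_layer=dict(
        #       max_num_points=32,
        #       point_cloud_range=[0, -40, -3, 70.4, 40, 1],
        #       voxel_size=[0.05, 0.05, 0.1],
        #       max_voxels=(16000, 40000)),
        #   voxel_encoder=dict(type='HardSimpleVFE'),
        #   middle_encoder=dict(type='SparseEncoder', in_channels=4,
        #                       sparse_shape=[41, 1600, 1408]),
        #   backbone=dict(type='SECOND', in_channels=256, layer_nums=[5, 5],
        #                 layer_strides=[1, 2], out_channels=[128, 256]),
        # ``voxel_layer`` is passed straight to ``mmcv.ops.Voxelization``; the
        # remaining dicts are built through the ``MODELS`` registry below.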
        super(VoxelNet, self).__init__(
            backbone=backbone,
            neck=neck,
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            init_cfg=init_cfg,
            pretrained=pretrained)
        self.voxel_layer = Voxelization(**voxel_layer)
        self.voxel_encoder = MODELS.build(voxel_encoder)
        self.middle_encoder = MODELS.build(middle_encoder)

    def extract_feat(self, points: List[torch.Tensor]) -> list:
        """Extract features from points."""
        voxels, num_points, coors = self.voxelize(points)
        voxel_features = self.voxel_encoder(voxels, num_points, coors)
        batch_size = coors[-1, 0].item() + 1
        x = self.middle_encoder(voxel_features, coors, batch_size)
        x = self.backbone(x)
        if self.with_neck:
            x = self.neck(x)
        return x

    @torch.no_grad()
    @force_fp32()
    def voxelize(self, points: List[torch.Tensor]) -> tuple:
        """Apply hard voxelization to points."""
        voxels, coors, num_points = [], [], []
        for res in points:
            res_voxels, res_coors, res_num_points = self.voxel_layer(res)
            voxels.append(res_voxels)
            coors.append(res_coors)
            num_points.append(res_num_points)
        voxels = torch.cat(voxels, dim=0)
        num_points = torch.cat(num_points, dim=0)
        coors_batch = []
        for i, coor in enumerate(coors):
            coor_pad = F.pad(coor, (1, 0), mode='constant', value=i)
            coors_batch.append(coor_pad)
        coors_batch = torch.cat(coors_batch, dim=0)
        return voxels, num_points, coors_batch

    def forward_train(self, batch_inputs_dict: Dict[str, torch.Tensor],
                      batch_data_samples: List[Det3DDataSample],
                      **kwargs) -> dict:
        """
        Args:
            batch_inputs_dict (dict): The model input dict. It should contain
                ``points`` and ``imgs`` keys.

                    - points (list[torch.Tensor]): Point cloud of each sample.
                    - imgs (torch.Tensor, optional): Image of each sample.

            batch_data_samples (list[:obj:`Det3DDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance_3d` or `gt_panoptic_seg_3d` or `gt_sem_seg_3d`.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        x = self.extract_feat(batch_inputs_dict['points'])
        losses = self.bbox_head.forward_train(x, batch_data_samples, **kwargs)
        return losses

    def simple_test(self,
                    batch_inputs_dict: Dict[str, torch.Tensor],
                    batch_input_metas: List[dict],
                    rescale: bool = False) -> list:
        """Test function without test-time augmentation.

        Args:
            batch_inputs_dict (dict): The model input dict. It should contain
                ``points`` and ``imgs`` keys.

                    - points (list[torch.Tensor]): Point cloud of a single
                        sample.
                    - imgs (torch.Tensor, optional): Image of a single sample.

            batch_input_metas (list[dict]): List of input information.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.

        Returns:
            list[:obj:`Det3DDataSample`]: Detection results of the \
                inputs. Each Det3DDataSample usually contains \
                'pred_instances_3d', and the ``pred_instances_3d`` usually \
                contains the following keys.

                - scores_3d (Tensor): Classification scores, has a shape
                    (num_instances, )
                - labels_3d (Tensor): Labels of bboxes, has a shape
                    (num_instances, ).
                - bboxes_3d (:obj:`BaseInstance3DBoxes`): Prediction of bboxes,
                    contains a tensor with shape (num_instances, 7).
        """
        x = self.extract_feat(batch_inputs_dict['points'])
        bboxes_list = self.bbox_head.simple_test(
            x, batch_input_metas, rescale=rescale)

        # convert to Det3DDataSample
        results_list = self.postprocess_result(bboxes_list)
        return results_list

    def aug_test(self,
                 aug_batch_inputs_dict: Dict[str, torch.Tensor],
                 aug_batch_input_metas: List[dict],
                 rescale: bool = False) -> list:
        """Test function with augmentaiton."""
        # TODO Refactor this after mmdet update
        feats = self.extract_feats(aug_batch_inputs_dict)
        aug_bboxes = self.bbox_head.aug_test(
            feats, aug_batch_input_metas, rescale=rescale)
        return aug_bboxes