"examples/offline_inference_mlpspeculator.py" did not exist on "c6fb83b283e6f62be6272bc9798052dc17d7b280"
imvoxelnet.py 8.87 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple, Union

import torch

from mmdet3d.core import Det3DDataSample, InstanceList, build_prior_generator
from mmdet3d.core.utils import ConfigType, OptConfigType, SampleList
from mmdet3d.models.fusion_layers.point_fusion import point_sample
from mmdet3d.registry import MODELS
from mmdet.models.detectors import BaseDetector


@MODELS.register_module()
class ImVoxelNet(BaseDetector):
    r"""`ImVoxelNet <https://arxiv.org/abs/2106.01178>`_.

    A projection-based 3D detector: 2D image features from a
    backbone + neck are sampled into a voxel volume at 3D grid points,
    refined by a 3D neck and decoded by an anchor-based 3D bbox head.

    Args:
        backbone (:obj:`ConfigDict` or dict): The backbone config.
        neck (:obj:`ConfigDict` or dict): The neck config.
        neck_3d (:obj:`ConfigDict` or dict): The 3D neck config.
        bbox_head (:obj:`ConfigDict` or dict): The bbox head config.
        n_voxels (list): Number of voxels along x, y, z axis.
        anchor_generator (:obj:`ConfigDict` or dict): The anchor generator
            config.
        train_cfg (:obj:`ConfigDict` or dict, optional): Config dict of
            training hyper-parameters. Defaults to None.
        test_cfg (:obj:`ConfigDict` or dict, optional): Config dict of test
            hyper-parameters. Defaults to None.
        data_preprocessor (dict or ConfigDict, optional): The pre-process
            config of :class:`BaseDataPreprocessor`. It usually includes
            ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
        init_cfg (:obj:`ConfigDict` or dict, optional): The initialization
            config. Defaults to None.
    """

    def __init__(self,
                 backbone: ConfigType,
                 neck: ConfigType,
                 neck_3d: ConfigType,
                 bbox_head: ConfigType,
                 n_voxels: List,
                 anchor_generator: ConfigType,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptConfigType = None):
        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)
        # 2D feature extractor and 3D refinement network.
        self.backbone = MODELS.build(backbone)
        self.neck = MODELS.build(neck)
        self.neck_3d = MODELS.build(neck_3d)
        # The head needs the train/test configs already merged in at
        # build time, so update the config dict before building.
        bbox_head.update(train_cfg=train_cfg)
        bbox_head.update(test_cfg=test_cfg)
        self.bbox_head = MODELS.build(bbox_head)
        self.n_voxels = n_voxels
        # The anchor grid doubles as the set of 3D sampling locations
        # used by ``extract_feat``.
        self.anchor_generator = build_prior_generator(anchor_generator)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

    def convert_to_datasample(self, results_list: InstanceList) -> SampleList:
        """Wrap each per-image prediction in a :obj:`Det3DDataSample`.

        Args:
            results_list (list[:obj:`InstanceData`]): 3D Detection results
                of each image.

        Returns:
            list[:obj:`Det3DDataSample`]: 3D Detection results of the
            input images. Each Det3DDataSample usually contains
            'pred_instances_3d', which in turn usually contains:

                - scores_3d (Tensor): Classification scores, has a shape
                    (num_instance, )
                - labels_3d (Tensor): Labels of bboxes, has a shape
                    (num_instances, ).
                - bboxes_3d (Tensor): Contains a tensor with shape
                    (num_instances, C) where C >=7.
        """
        out_results_list = []
        for pred_instances in results_list:
            data_sample = Det3DDataSample()
            data_sample.pred_instances_3d = pred_instances
            out_results_list.append(data_sample)
        return out_results_list

    def extract_feat(self, batch_inputs_dict: dict,
                     batch_data_samples: SampleList):
        """Extract 3d features from the backbone -> fpn -> 3d projection.

        Args:
            batch_inputs_dict (dict): The model input dict which includes
                the 'imgs' key.

                    - imgs (torch.Tensor, optional): Image of each sample.
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples; their meta info carries the per-image
                projection matrix and augmentation parameters.

        Returns:
            torch.Tensor: of shape (N, C_out, N_x, N_y, N_z)
        """
        imgs = batch_inputs_dict['imgs']
        batch_img_metas = [sample.metainfo for sample in batch_data_samples]
        # Only the first neck output level is used for projection.
        feats_2d = self.neck(self.backbone(imgs))[0]
        # Anchor centers are the 3D sampling locations; ``n_voxels`` is
        # reversed here to match the generator's grid ordering (cf. the
        # matching reshape/permute below).
        points = self.anchor_generator.grid_anchors(
            [self.n_voxels[::-1]], device=imgs.device)[0][:, :3]
        volumes = []
        for feat, meta in zip(feats_2d, batch_img_metas):
            # Augmentation parameters fall back to the identity when the
            # corresponding transform was not applied to this image.
            scale_factor = (
                points.new_tensor(meta['scale_factor'][:2])
                if 'scale_factor' in meta else 1)
            flip = meta['flip'] if 'flip' in meta else False
            crop_offset = (
                points.new_tensor(meta['img_crop_offset'])
                if 'img_crop_offset' in meta else 0)
            proj_mat = points.new_tensor(meta['lidar2img'])
            # Sample 2D features at the projected 3D points.
            volume = point_sample(
                meta,
                img_features=feat[None, ...],
                points=points,
                proj_mat=proj_mat,
                coord_type='LIDAR',
                img_scale_factor=scale_factor,
                img_crop_offset=crop_offset,
                img_flip=flip,
                img_pad_shape=imgs.shape[-2:],
                img_shape=meta['img_shape'][:2],
                aligned=False)
            # (n_points, C) -> (Nz, Ny, Nx, C) -> (C, Nx, Ny, Nz).
            volumes.append(
                volume.reshape(self.n_voxels[::-1] + [-1]).permute(3, 2, 1, 0))
        voxel_feats = torch.stack(volumes)
        return self.neck_3d(voxel_feats)

    def loss(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
             **kwargs) -> Union[dict, list]:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            batch_inputs_dict (dict): The model input dict which includes
                the 'imgs' key.

                    - imgs (torch.Tensor, optional): Image of each sample.
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.

        Returns:
            dict: A dictionary of loss components.
        """
        feats = self.extract_feat(batch_inputs_dict, batch_data_samples)
        return self.bbox_head.loss(feats, batch_data_samples, **kwargs)

    def predict(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
                **kwargs) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        Args:
            batch_inputs_dict (dict): The model input dict which includes
                the 'imgs' key.

                    - imgs (torch.Tensor, optional): Image of each sample.
            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`.

        Returns:
            list[:obj:`Det3DDataSample`]: Detection results of the
            input images. Each Det3DDataSample usually contains
            'pred_instances_3d', which in turn usually contains:

                - scores_3d (Tensor): Classification scores, has a shape
                    (num_instance, )
                - labels_3d (Tensor): Labels of bboxes, has a shape
                    (num_instances, ).
                - bboxes_3d (Tensor): Contains a tensor with shape
                    (num_instances, C) where C >=7.
        """
        feats = self.extract_feat(batch_inputs_dict, batch_data_samples)
        results_list = self.bbox_head.predict(feats, batch_data_samples,
                                              **kwargs)
        return self.convert_to_datasample(results_list)

    def _forward(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
                 *args, **kwargs) -> Tuple[List[torch.Tensor]]:
        """Network forward process. Usually includes backbone, neck and head
        forward without any post-processing.

        Args:
            batch_inputs_dict (dict): The model input dict which includes
                the 'imgs' key.

                    - imgs (torch.Tensor, optional): Image of each sample.
            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`.

        Returns:
            tuple[list]: A tuple of features from ``bbox_head`` forward.
        """
        feats = self.extract_feat(batch_inputs_dict, batch_data_samples)
        return self.bbox_head.forward(feats)