imvoxelnet.py 11.6 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
Tai-Wang's avatar
Tai-Wang committed
2
3
from typing import List, Tuple, Union

4
import torch
5
from mmengine.structures import InstanceData
6

7
from mmdet3d.models.detectors import Base3DDetector
zhangshilong's avatar
zhangshilong committed
8
9
from mmdet3d.models.layers.fusion_layers.point_fusion import point_sample
from mmdet3d.registry import MODELS, TASK_UTILS
10
from mmdet3d.structures.bbox_3d import get_proj_mat_by_coord_type
zhangshilong's avatar
zhangshilong committed
11
from mmdet3d.structures.det3d_data_sample import SampleList
12
from mmdet3d.utils import ConfigType, OptConfigType, OptInstanceList
13
14


15
@MODELS.register_module()
class ImVoxelNet(Base3DDetector):
    r"""`ImVoxelNet <https://arxiv.org/abs/2106.01178>`_.

    Args:
        backbone (:obj:`ConfigDict` or dict): The backbone config.
        neck (:obj:`ConfigDict` or dict): The neck config.
        neck_3d (:obj:`ConfigDict` or dict): The 3D neck config.
        bbox_head (:obj:`ConfigDict` or dict): The bbox head config. It is
            updated in place with ``train_cfg`` and ``test_cfg`` before
            being built.
        prior_generator (:obj:`ConfigDict` or dict): The prior points
            generator config.
        n_voxels (list[int] | tuple[int]): Number of voxels along the x, y
            and z axes.
        coord_type (str): The type of coordinates of points cloud:
            'DEPTH', 'LIDAR', or 'CAMERA'. With 'DEPTH', the voxel valid
            mask is appended to the inputs of ``bbox_head``.
        train_cfg (:obj:`ConfigDict` or dict, optional): Config dict of
            training hyper-parameters. Defaults to None.
        test_cfg (:obj:`ConfigDict` or dict, optional): Config dict of test
            hyper-parameters. Defaults to None.
        data_preprocessor (:obj:`ConfigDict` or dict, optional): The
            pre-process config of :class:`BaseDataPreprocessor`. It usually
            includes ``pad_size_divisor``, ``pad_value``, ``mean`` and
            ``std``. Defaults to None.
        init_cfg (:obj:`ConfigDict` or dict, optional): The initialization
            config. Defaults to None.
    """

    def __init__(self,
                 backbone: ConfigType,
                 neck: ConfigType,
                 neck_3d: ConfigType,
                 bbox_head: ConfigType,
                 prior_generator: ConfigType,
                 n_voxels: List,
                 coord_type: str,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptConfigType = None):
        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)
        self.backbone = MODELS.build(backbone)
        self.neck = MODELS.build(neck)
        self.neck_3d = MODELS.build(neck_3d)
        # The head receives the train/test configs through its own config
        # (note: this mutates the caller's ``bbox_head`` config).
        bbox_head.update(train_cfg=train_cfg)
        bbox_head.update(test_cfg=test_cfg)
        self.bbox_head = MODELS.build(bbox_head)
        self.prior_generator = TASK_UTILS.build(prior_generator)
        self.n_voxels = n_voxels
        self.coord_type = coord_type
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

    def extract_feat(self, batch_inputs_dict: dict,
                     batch_data_samples: SampleList):
        """Extract 3D features: backbone -> FPN -> 3D projection -> 3D neck.

        The first output level of ``neck`` for each image is sampled onto
        the voxel grid produced by ``prior_generator`` using
        :func:`point_sample`, and the stacked per-image volumes are passed
        through ``neck_3d``.

        Args:
            batch_inputs_dict (dict): The model input dict. It must contain
                the 'imgs' key.

                    - imgs (torch.Tensor): Batched images.
            batch_data_samples (list[:obj:`Det3DDataSample`]): The batch
                data samples. Their ``metainfo`` provides the projection
                matrix and ``img_shape`` of each image, and optionally
                ``scale_factor``, ``flip`` and ``img_crop_offset``.

        Returns:
            tuple:
            - The output of ``neck_3d`` applied to the stacked volumes of
              shape (N, C, N_x, N_y, N_z).
            - torch.Tensor: Float valid mask of shape (N, 1, N_x, N_y, N_z);
              a voxel is valid when at least one of its sampled feature
              channels is non-zero.
        """
        img = batch_inputs_dict['imgs']
        batch_img_metas = [
            data_samples.metainfo for data_samples in batch_data_samples
        ]
        x = self.backbone(img)
        # Only the first output level of the 2D neck is used.
        x = self.neck(x)[0]

        # The grid size is given in reversed (z, y, x) order. Converting to
        # a list lets ``n_voxels`` be either a list or a tuple (a tuple used
        # to fail on ``tuple + [-1]`` below).
        n_voxels_zyx = list(self.n_voxels[::-1])
        # Keep the first three columns of each prior (presumably the point
        # coordinates).
        points = self.prior_generator.grid_anchors(
            [n_voxels_zyx], device=img.device)[0][:, :3]

        volumes, valid_preds = [], []
        for feature, img_meta in zip(x, batch_img_metas):
            # Optional augmentation metadata; neutral defaults otherwise.
            img_scale_factor = (
                points.new_tensor(img_meta['scale_factor'][:2])
                if 'scale_factor' in img_meta else 1)
            img_flip = img_meta.get('flip', False)
            img_crop_offset = (
                points.new_tensor(img_meta['img_crop_offset'])
                if 'img_crop_offset' in img_meta else 0)
            proj_mat = points.new_tensor(
                get_proj_mat_by_coord_type(img_meta, self.coord_type))
            volume = point_sample(
                img_meta,
                img_features=feature[None, ...],
                points=points,
                # ``proj_mat`` is already a tensor on the right device and
                # dtype; wrapping it again with ``new_tensor`` would make a
                # redundant copy and trigger a copy-construct warning.
                proj_mat=proj_mat,
                coord_type=self.coord_type,
                img_scale_factor=img_scale_factor,
                img_crop_offset=img_crop_offset,
                img_flip=img_flip,
                img_pad_shape=img.shape[-2:],
                img_shape=img_meta['img_shape'][:2],
                aligned=False)
            # Reshape to (N_z, N_y, N_x, C), then permute to
            # (C, N_x, N_y, N_z).
            volume = volume.reshape(n_voxels_zyx + [-1]).permute(3, 2, 1, 0)
            volumes.append(volume)
            # A voxel counts as valid unless all of its channels are zero.
            valid_preds.append(~torch.all(volume == 0, dim=0, keepdim=True))
        x = torch.stack(volumes)
        x = self.neck_3d(x)
        return x, torch.stack(valid_preds).float()

    def _extract_head_inputs(self, batch_inputs_dict: dict,
                             batch_data_samples: SampleList):
        """Run :meth:`extract_feat` and assemble the inputs of ``bbox_head``.

        For indoor datasets (``coord_type == 'DEPTH'``) ImVoxelNet uses
        ImVoxelHead, which also handles the mask of visible voxels, so the
        valid mask is appended to the 3D neck features.

        Args:
            batch_inputs_dict (dict): The model input dict containing the
                'imgs' key.
            batch_data_samples (list[:obj:`Det3DDataSample`]): The batch
                data samples.

        Returns:
            The ``neck_3d`` output, extended with the valid mask when
            ``coord_type`` is 'DEPTH'.
        """
        x, valid_preds = self.extract_feat(batch_inputs_dict,
                                           batch_data_samples)
        if self.coord_type == 'DEPTH':
            x += (valid_preds, )
        return x

    def loss(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
             **kwargs) -> Union[dict, list]:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            batch_inputs_dict (dict): The model input dict. It must contain
                the 'imgs' key.

                    - imgs (torch.Tensor): Batched images.
            batch_data_samples (list[:obj:`Det3DDataSample`]): The batch
                data samples, forwarded to ``bbox_head.loss``. They usually
                include ground-truth annotations such as
                ``gt_instances_3d``.
            **kwargs: Extra arguments forwarded to ``bbox_head.loss``.

        Returns:
            dict: A dictionary of loss components.
        """
        x = self._extract_head_inputs(batch_inputs_dict, batch_data_samples)
        losses = self.bbox_head.loss(x, batch_data_samples, **kwargs)
        return losses

    def predict(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
                **kwargs) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        Args:
            batch_inputs_dict (dict): The model input dict. It must contain
                the 'imgs' key.

                    - imgs (torch.Tensor): Batched images.
            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
                Samples. They are forwarded to ``bbox_head.predict`` and
                receive the predictions.
            **kwargs: Extra arguments forwarded to ``bbox_head.predict``.

        Returns:
            list[:obj:`Det3DDataSample`]: Detection results of the
            input images. Each Det3DDataSample usually contains
            'pred_instances_3d'. And the ``pred_instances_3d`` usually
            contains the following keys.

                - scores_3d (Tensor): Classification scores, has a shape
                    (num_instance, )
                - labels_3d (Tensor): Labels of bboxes, has a shape
                    (num_instances, ).
                - bboxes_3d (Tensor): Contains a tensor with shape
                    (num_instances, C) where C >=7.
        """
        x = self._extract_head_inputs(batch_inputs_dict, batch_data_samples)
        results_list = self.bbox_head.predict(x, batch_data_samples,
                                              **kwargs)
        predictions = self.add_pred_to_datasample(batch_data_samples,
                                                  results_list)
        return predictions

    def _forward(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
                 *args, **kwargs) -> Tuple[List[torch.Tensor]]:
        """Network forward process. Usually includes backbone, neck and head
        forward without any post-processing.

        Args:
            batch_inputs_dict (dict): The model input dict. It must contain
                the 'imgs' key.

                    - imgs (torch.Tensor): Batched images.
            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
                Samples; their ``metainfo`` is used for the 3D projection.

        Returns:
            tuple[list]: A tuple of features from ``bbox_head`` forward.
        """
        x = self._extract_head_inputs(batch_inputs_dict, batch_data_samples)
        results = self.bbox_head.forward(x)
        return results

    def convert_to_datasample(
        self,
        data_samples: SampleList,
        data_instances_3d: OptInstanceList = None,
        data_instances_2d: OptInstanceList = None,
    ) -> SampleList:
        """Convert results list to `Det3DDataSample`.

        Subclasses could override it to be compatible for some multi-modality
        3D detectors. Note that :meth:`predict` in this class calls
        ``add_pred_to_datasample`` rather than this method.

        Args:
            data_samples (list[:obj:`Det3DDataSample`]): The input data.
                They are modified in place and returned.
            data_instances_3d (list[:obj:`InstanceData`], optional): 3D
                Detection results of each sample.
            data_instances_2d (list[:obj:`InstanceData`], optional): 2D
                Detection results of each sample.

        Returns:
            list[:obj:`Det3DDataSample`]: Detection results of the
            input. Each Det3DDataSample usually contains
            'pred_instances_3d'. And the ``pred_instances_3d`` normally
            contains the following keys.

            - scores_3d (Tensor): Classification scores, has a shape
              (num_instance, )
            - labels_3d (Tensor): Labels of 3D bboxes, has a shape
              (num_instances, ).
            - bboxes_3d (Tensor): Contains a tensor with shape
              (num_instances, C) where C >=7.

            When a model also produces image predictions, each sample
            contains ``pred_instances``, which normally contains the
            following keys.

            - scores (Tensor): Classification scores of image, has a shape
              (num_instance, )
            - labels (Tensor): Predict Labels of 2D bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Contains a tensor with shape
              (num_instances, 4).

        Raises:
            AssertionError: If both ``data_instances_3d`` and
                ``data_instances_2d`` are None.
        """
        assert (data_instances_2d is not None) or \
               (data_instances_3d is not None),\
               'please pass at least one type of data_samples'

        # Fill the missing modality with empty instances, one per sample.
        if data_instances_2d is None:
            data_instances_2d = [InstanceData() for _ in data_instances_3d]
        if data_instances_3d is None:
            data_instances_3d = [InstanceData() for _ in data_instances_2d]

        for i, data_sample in enumerate(data_samples):
            data_sample.pred_instances_3d = data_instances_3d[i]
            data_sample.pred_instances = data_instances_2d[i]
        return data_samples