# Copyright (c) OpenMMLab. All rights reserved.
import math
from numbers import Number
from typing import Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
from mmcv.ops import Voxelization
from mmengine.model import stack_batch
from mmengine.utils import is_list_of
from torch.nn import functional as F

from mmdet3d.registry import MODELS
from mmdet3d.utils import OptConfigType
from mmdet.models import DetDataPreprocessor
from mmdet.models.utils.misc import samplelist_boxtype2tensor


@MODELS.register_module()
class Det3DDataPreprocessor(DetDataPreprocessor):
    """Points / Image pre-processor for point clouds / vision-only / multi-
    modality 3D detection tasks.

    It provides the data pre-processing as follows:

    - Collate and move image and point cloud data to the target device.

    - 1) For image data:

      - Pad images in inputs to the maximum size of the current batch with
        the defined ``pad_value``. The padded size is divisible by the
        defined ``pad_size_divisor``.
      - Stack images in inputs to batch_imgs.
      - Convert images in inputs from BGR to RGB if the shape of the input
        is (3, H, W).
      - Normalize images in inputs with the defined std and mean.
      - Do batch augmentations during training.

    - 2) For point cloud data:

      - If no voxelization is applied, directly return the list of point
        cloud data.
      - If voxelization is applied, voxelize the point cloud according to
        ``voxel_type`` and obtain ``voxels``.

    Args:
        voxel (bool): Whether to apply voxelization to the point cloud.
            Defaults to False.
        voxel_type (str): Voxelization type. Two voxelization types are
            provided: 'hard' and 'dynamic', respectively for hard
            voxelization and dynamic voxelization. Defaults to 'hard'.
        voxel_layer (:obj:`ConfigDict`, optional): Voxelization layer
            config. Defaults to None.
        mean (Sequence[Number], optional): The pixel mean of R, G, B channels.
            Defaults to None.
        std (Sequence[Number], optional): The pixel standard deviation of
            R, G, B channels. Defaults to None.
        pad_size_divisor (int): The size of padded image should be
            divisible by ``pad_size_divisor``. Defaults to 1.
        pad_value (Number): The padded pixel value. Defaults to 0.
        bgr_to_rgb (bool): whether to convert image from BGR to RGB.
            Defaults to False.
        rgb_to_bgr (bool): whether to convert image from RGB to BGR.
            Defaults to False.
        pad_mask (bool): Whether to pad instance masks. Defaults to False.
        mask_pad_value (int): The padded pixel value for instance masks.
            Defaults to 0.
        pad_seg (bool): Whether to pad semantic segmentation maps.
            Defaults to False.
        seg_pad_value (int): The padded pixel value for semantic
            segmentation maps. Defaults to 255.
        boxtype2tensor (bool): Whether to convert the ``BaseBoxes`` type of
            bboxes data to Tensor type. Defaults to True.
        batch_augments (list[dict], optional): Batch-level augmentations.
            Defaults to None.
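
    Example:
        A minimal, hypothetical config sketch. The ``voxel_layer`` keys
        follow :class:`mmcv.ops.Voxelization`; every value below is
        illustrative rather than a recommended setting::

            >>> data_preprocessor = dict(
            ...     type='Det3DDataPreprocessor',
            ...     voxel=True,
            ...     voxel_layer=dict(
            ...         max_num_points=32,
            ...         point_cloud_range=[0, -40, -3, 70.4, 40, 1],
            ...         voxel_size=[0.05, 0.05, 0.1],
            ...         max_voxels=(16000, 40000)),
            ...     mean=[123.675, 116.28, 103.53],
            ...     std=[58.395, 57.12, 57.375],
            ...     bgr_to_rgb=True)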
    """

    def __init__(self,
                 voxel: bool = False,
                 voxel_type: str = 'hard',
                 voxel_layer: OptConfigType = None,
                 mean: Sequence[Number] = None,
                 std: Sequence[Number] = None,
                 pad_size_divisor: int = 1,
                 pad_value: Union[float, int] = 0,
                 pad_mask: bool = False,
                 mask_pad_value: int = 0,
                 pad_seg: bool = False,
                 seg_pad_value: int = 255,
                 bgr_to_rgb: bool = False,
                 rgb_to_bgr: bool = False,
                 boxtype2tensor: bool = True,
                 batch_augments: Optional[List[dict]] = None):
        super().__init__(
            mean=mean,
            std=std,
            pad_size_divisor=pad_size_divisor,
            pad_value=pad_value,
            pad_mask=pad_mask,
            mask_pad_value=mask_pad_value,
            pad_seg=pad_seg,
            seg_pad_value=seg_pad_value,
            bgr_to_rgb=bgr_to_rgb,
            rgb_to_bgr=rgb_to_bgr,
            boxtype2tensor=boxtype2tensor,
            batch_augments=batch_augments)
        self.voxel = voxel
        self.voxel_type = voxel_type
        if voxel:
            self.voxel_layer = Voxelization(**voxel_layer)

    def forward(
        self,
        data: Union[dict, List[dict]],
        training: bool = False
    ) -> Tuple[Union[dict, List[dict]], Optional[list]]:
        """Perform normalization、padding and bgr2rgb conversion based on
        ``BaseDataPreprocessor``.

        Args:
            data (dict | List[dict]): Data from the dataloader. A dict
                contains the data of the whole batch; a list[dict] indicates
                test-time augmentation, with one dict per augmented view.
            training (bool): Whether to enable training time augmentation.
                Defaults to False.

        Returns:
            Dict | List[Dict]: Data in the same format as the model input.
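
        Example:
            A hedged sketch of the two accepted input forms; the ``...``
            placeholders stand for real batch data and are not runnable::

                >>> # a single batch sampled from the dataloader
                >>> out = preprocessor(dict(inputs=..., data_samples=...))
                >>> # test-time augmentation: one dict per augmented view,
                >>> # each processed independently by `simple_process`
                >>> outs = preprocessor([aug_data_0, aug_data_1])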
        """
        if isinstance(data, list):
            num_augs = len(data)
            aug_batch_data = []
            for aug_id in range(num_augs):
                single_aug_batch_data = self.simple_process(
                    data[aug_id], training)
                aug_batch_data.append(single_aug_batch_data)
            return aug_batch_data

        else:
            return self.simple_process(data, training)

    def simple_process(self, data: dict, training: bool = False) -> dict:
        """Perform normalization、padding and bgr2rgb conversion for img data
        based on ``BaseDataPreprocessor``, and voxelize point cloud if `voxel`
        is set to be True.

        Args:
            data (dict): Data sampled from dataloader.
            training (bool): Whether to enable training time augmentation.
                Defaults to False.

        Returns:
            dict: Data in the same format as the model input.
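
        Example:
            A sketch of the returned structure; the keys present under
            ``'inputs'`` depend on the modalities in ``data['inputs']`` and
            on ``self.voxel``::

                >>> out = self.simple_process(data)
                >>> out.keys()
                dict_keys(['inputs', 'data_samples'])
                >>> out['inputs'].keys()  # e.g. for a multi-modality batch
                dict_keys(['points', 'voxels', 'imgs'])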
        """
        if 'img' in data['inputs']:
            batch_pad_shape = self._get_pad_shape(data)

        data = self.collate_data(data)
        inputs, data_samples = data['inputs'], data['data_samples']

        batch_inputs = dict()

        if 'points' in inputs:
            batch_inputs['points'] = inputs['points']

            if self.voxel:
                voxel_dict = self.voxelize(inputs['points'])
                batch_inputs['voxels'] = voxel_dict

        if 'imgs' in inputs:
            imgs = inputs['imgs']

            if data_samples is not None:
                # NOTE the batched image size information may be useful, e.g.
                # in DETR, this is needed for the construction of masks, which
                # is then used for the transformer_head.
                batch_input_shape = tuple(imgs[0].size()[-2:])
                for data_sample, pad_shape in zip(data_samples,
                                                  batch_pad_shape):
                    data_sample.set_metainfo({
                        'batch_input_shape': batch_input_shape,
                        'pad_shape': pad_shape
                    })

                if self.boxtype2tensor:
                    samplelist_boxtype2tensor(data_samples)

                if self.pad_mask:
                    self.pad_gt_masks(data_samples)

                if self.pad_seg:
                    self.pad_gt_sem_seg(data_samples)

            if training and self.batch_augments is not None:
                for batch_aug in self.batch_augments:
                    imgs, data_samples = batch_aug(imgs, data_samples)
            batch_inputs['imgs'] = imgs

        return {'inputs': batch_inputs, 'data_samples': data_samples}

    def collate_data(self, data: dict) -> dict:
        """Copying data to the target device and Performs normalization、
        padding and bgr2rgb conversion and stack based on
        ``BaseDataPreprocessor``.

        Collates the data sampled from the dataloader into a list of dict
        and a list of labels, and then copies the tensors to the target
        device.

        Args:
            data (dict): Data sampled from dataloader.

        Returns:
            dict: Data in the same format as the model input.
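
        Example:
            A minimal sketch of the list-of-tensors image path (shapes are
            illustrative): each (3, H_i, W_i) image is optionally
            channel-flipped, cast to float, normalized as
            ``(img - mean) / std``, then padded and stacked by
            ``stack_batch`` under the ``'imgs'`` key::

                >>> data = dict(inputs=dict(img=[
                ...     torch.randint(0, 256, (3, 375, 1242), dtype=torch.uint8),
                ...     torch.randint(0, 256, (3, 370, 1224), dtype=torch.uint8)]))
                >>> out = self.collate_data(data)
                >>> out['inputs']['imgs'].shape  # (2, 3, H_pad, W_pad)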
        """
        data = self.cast_data(data)  # type: ignore

        if 'img' in data['inputs']:
            _batch_imgs = data['inputs']['img']

            # Process data with `pseudo_collate`.
            if is_list_of(_batch_imgs, torch.Tensor):
                batch_imgs = []
                for _batch_img in _batch_imgs:
                    # channel transform
                    if self._channel_conversion:
                        _batch_img = _batch_img[[2, 1, 0], ...]
                    # Convert to float after channel conversion to ensure
                    # efficiency
                    _batch_img = _batch_img.float()
                    # Normalization.
                    if self._enable_normalize:
                        if self.mean.shape[0] == 3:
                            assert _batch_img.dim(
                            ) == 3 and _batch_img.shape[0] == 3, (
                                'If the mean has 3 values, the input tensor '
                                'should in shape of (3, H, W), but got the '
                                f'tensor with shape {_batch_img.shape}')
                        _batch_img = (_batch_img - self.mean) / self.std
                    batch_imgs.append(_batch_img)
                # Pad and stack Tensor.
                batch_imgs = stack_batch(batch_imgs, self.pad_size_divisor,
                                         self.pad_value)
            # Process data with `default_collate`.
            elif isinstance(_batch_imgs, torch.Tensor):
                assert _batch_imgs.dim() == 4, (
                    'The input of `ImgDataPreprocessor` should be a NCHW '
                    'tensor or a list of tensor, but got a tensor with '
                    f'shape: {_batch_imgs.shape}')
                if self._channel_conversion:
                    _batch_imgs = _batch_imgs[:, [2, 1, 0], ...]
                # Convert to float after channel conversion to ensure
                # efficiency
                _batch_imgs = _batch_imgs.float()
                if self._enable_normalize:
                    _batch_imgs = (_batch_imgs - self.mean) / self.std
                h, w = _batch_imgs.shape[2:]
                target_h = math.ceil(
                    h / self.pad_size_divisor) * self.pad_size_divisor
                target_w = math.ceil(
                    w / self.pad_size_divisor) * self.pad_size_divisor
                pad_h = target_h - h
                pad_w = target_w - w
                batch_imgs = F.pad(_batch_imgs, (0, pad_w, 0, pad_h),
                                   'constant', self.pad_value)
            else:
                raise TypeError(
                    'Output of `cast_data` should be a list of dict '
                    'or a tuple with inputs and data_samples, but got '
                    f'{type(data)}: {data}')

            data['inputs']['imgs'] = batch_imgs

        data.setdefault('data_samples', None)

        return data

    def _get_pad_shape(self, data: dict) -> List[tuple]:
        """Get the pad_shape of each image based on data and
        pad_size_divisor."""
        # rewrite `_get_pad_shape` for obaining image inputs.
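        # For example, with pad_size_divisor=32 an image of spatial size
        # (375, 1242) gets pad_shape (384, 1248), since
        # ceil(375 / 32) * 32 = 384 and ceil(1242 / 32) * 32 = 1248.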
        _batch_inputs = data['inputs']['img']
        # Process data with `pseudo_collate`.
        if is_list_of(_batch_inputs, torch.Tensor):
            batch_pad_shape = []
            for ori_input in _batch_inputs:
                pad_h = int(
                    np.ceil(ori_input.shape[1] /
                            self.pad_size_divisor)) * self.pad_size_divisor
                pad_w = int(
                    np.ceil(ori_input.shape[2] /
                            self.pad_size_divisor)) * self.pad_size_divisor
                batch_pad_shape.append((pad_h, pad_w))
        # Process data with `default_collate`.
        elif isinstance(_batch_inputs, torch.Tensor):
            assert _batch_inputs.dim() == 4, (
                'The input of `ImgDataPreprocessor` should be a NCHW tensor '
                'or a list of tensor, but got a tensor with shape: '
                f'{_batch_inputs.shape}')
            pad_h = int(
                np.ceil(_batch_inputs.shape[2] /
                        self.pad_size_divisor)) * self.pad_size_divisor
            pad_w = int(
                np.ceil(_batch_inputs.shape[3] /
                        self.pad_size_divisor)) * self.pad_size_divisor
            batch_pad_shape = [(pad_h, pad_w)] * _batch_inputs.shape[0]
        else:
            raise TypeError('Output of `cast_data` should be a list of dict '
                            'or a tuple with inputs and data_samples, but got '
                            f'{type(data)}: {data}')
        return batch_pad_shape

    @torch.no_grad()
    def voxelize(self, points: List[torch.Tensor]) -> Dict:
        """Apply voxelization to point cloud.

        Args:
            points (List[Tensor]): Point cloud in one data batch.

        Returns:
            dict[str, Tensor]: Voxelization information.

            - voxels (Tensor): Features of voxels, shape is MxNxC for hard
                voxelization, NxC for dynamic voxelization.
            - coors (Tensor): Coordinates of voxels, shape is Nx(1+NDim),
                where 1 represents the batch index.
            - num_points (Tensor, optional): Number of points in each voxel.
            - voxel_centers (Tensor, optional): Centers of voxels.
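
        Example:
            A hedged shape sketch for the hard-voxelization branch; the
            point counts are illustrative, with ``M`` voxels, at most ``N``
            points per voxel and ``C`` features per point::

                >>> pts = [torch.rand(1000, 4), torch.rand(1200, 4)]
                >>> out = self.voxelize(pts)  # assuming voxel_type == 'hard'
                >>> out['voxels'].shape       # (M, N, C)
                >>> out['coors'].shape        # (M, 4): batch_idx, z, y, x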
        """

        voxel_dict = dict()

        if self.voxel_type == 'hard':
            voxels, coors, num_points, voxel_centers = [], [], [], []
            for res in points:
                res_voxels, res_coors, res_num_points = self.voxel_layer(res)
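                # coors are stored as (z, y, x); reorder to (x, y, z), shift
                # by half a voxel to the center, scale by the voxel size and
                # offset by the point-cloud-range origin to get metric centers.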
                res_voxel_centers = (
                    res_coors[:, [2, 1, 0]] + 0.5) * res_voxels.new_tensor(
                        self.voxel_layer.voxel_size) + res_voxels.new_tensor(
                            self.voxel_layer.point_cloud_range[0:3])
                voxels.append(res_voxels)
                coors.append(res_coors)
                num_points.append(res_num_points)
                voxel_centers.append(res_voxel_centers)

            voxels = torch.cat(voxels, dim=0)
            num_points = torch.cat(num_points, dim=0)
            voxel_centers = torch.cat(voxel_centers, dim=0)
            coors_batch = []
            for i, coor in enumerate(coors):
                coor_pad = F.pad(coor, (1, 0), mode='constant', value=i)
                coors_batch.append(coor_pad)
            coors_batch = torch.cat(coors_batch, dim=0)
            voxel_dict['num_points'] = num_points
            voxel_dict['voxel_centers'] = voxel_centers
        elif self.voxel_type == 'dynamic':
            coors = []
            # dynamic voxelization only provides a coors mapping
            for res in points:
                res_coors = self.voxel_layer(res)
                coors.append(res_coors)
            voxels = torch.cat(points, dim=0)
            coors_batch = []
            for i, coor in enumerate(coors):
                coor_pad = F.pad(coor, (1, 0), mode='constant', value=i)
                coors_batch.append(coor_pad)
            coors_batch = torch.cat(coors_batch, dim=0)
        else:
            raise ValueError(f'Invalid voxelization type {self.voxel_type}')

        voxel_dict['voxels'] = voxels
        voxel_dict['coors'] = coors_batch

        return voxel_dict