# Copyright (c) OpenMMLab. All rights reserved.
import math
from numbers import Number
from typing import Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
from mmcv.ops import Voxelization
from mmengine.model import stack_batch
from mmengine.utils import is_list_of
from torch.nn import functional as F

from mmdet3d.registry import MODELS
from mmdet3d.utils import OptConfigType
from mmdet.models import DetDataPreprocessor
from .utils import multiview_img_stack_batch


@MODELS.register_module()
class Det3DDataPreprocessor(DetDataPreprocessor):
    """Points / Image pre-processor for point clouds / vision-only / multi-
    modality 3D detection tasks.

    It provides the data pre-processing as follows:

    - Collate and move image and point cloud data to the target device.

    - 1) For image data:
    - Pad images in inputs to the maximum size of current batch with defined
      ``pad_value``. The padding size is divisible by a defined
      ``pad_size_divisor``.
    - Stack images in inputs to batch_imgs.
    - Convert images in inputs from BGR to RGB if the shape of input is
      (3, H, W).
    - Normalize images in inputs with defined std and mean.
    - Do batch augmentations during training.

    - 2) For point cloud data:
    - If no voxelization is applied, directly return the list of point
      cloud data.
    - If voxelization is applied, voxelize the point cloud according to
      ``voxel_type`` and obtain ``voxels``.

    Args:
        voxel (bool): Whether to apply voxelization to point cloud.
        voxel_type (str): Voxelization type. Two voxelization types are
            provided: 'hard' and 'dynamic', respectively for hard
            voxelization and dynamic voxelization. Defaults to 'hard'.
        voxel_layer (:obj:`ConfigDict`, optional): Voxelization layer
            config. Defaults to None.
        mean (Sequence[Number], optional): The pixel mean of R, G, B channels.
            Defaults to None.
        std (Sequence[Number], optional): The pixel standard deviation of
            R, G, B channels. Defaults to None.
        pad_size_divisor (int): The size of padded image should be
            divisible by ``pad_size_divisor``. Defaults to 1.
        pad_value (Number): The padded pixel value. Defaults to 0.
        bgr_to_rgb (bool): Whether to convert image from BGR to RGB.
            Defaults to False.
        rgb_to_bgr (bool): Whether to convert image from RGB to BGR.
            Defaults to False.
        batch_augments (List[dict], optional): Batch-level augmentations.
            Defaults to None.
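
    Example:
        A minimal construction sketch. The values below are illustrative
        placeholders rather than recommended settings; the ``voxel_layer``
        keys follow :class:`mmcv.ops.Voxelization`::

            preprocessor = Det3DDataPreprocessor(
                voxel=True,
                voxel_type='hard',
                voxel_layer=dict(
                    max_num_points=32,
                    point_cloud_range=[0, -40, -3, 70.4, 40, 1],
                    voxel_size=[0.05, 0.05, 0.1],
                    max_voxels=(16000, 40000)),
                mean=[123.675, 116.28, 103.53],
                std=[58.395, 57.12, 57.375],
                bgr_to_rgb=True)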
    """

    def __init__(self,
                 voxel: bool = False,
                 voxel_type: str = 'hard',
                 voxel_layer: OptConfigType = None,
                 mean: Sequence[Number] = None,
                 std: Sequence[Number] = None,
                 pad_size_divisor: int = 1,
                 pad_value: Union[float, int] = 0,
                 pad_mask: bool = False,
                 mask_pad_value: int = 0,
                 pad_seg: bool = False,
                 seg_pad_value: int = 255,
                 bgr_to_rgb: bool = False,
                 rgb_to_bgr: bool = False,
                 boxtype2tensor: bool = True,
                 batch_augments: Optional[List[dict]] = None):
        super().__init__(
            mean=mean,
            std=std,
            pad_size_divisor=pad_size_divisor,
            pad_value=pad_value,
            pad_mask=pad_mask,
            mask_pad_value=mask_pad_value,
            pad_seg=pad_seg,
            seg_pad_value=seg_pad_value,
            bgr_to_rgb=bgr_to_rgb,
            rgb_to_bgr=rgb_to_bgr,
            batch_augments=batch_augments)
        self.voxel = voxel
        self.voxel_type = voxel_type
        if voxel:
            self.voxel_layer = Voxelization(**voxel_layer)

    def forward(
        self,
        data: Union[dict, List[dict]],
        training: bool = False
    ) -> Tuple[Union[dict, List[dict]], Optional[list]]:
        """Perform normalization、padding and bgr2rgb conversion based on
        ``BaseDataPreprocessor``.

        Args:
            data (dict | List[dict]): Data from dataloader. The dict contains
                the whole batch data; when it is a list[dict], the list
                indicates test-time augmentation.
            training (bool): Whether to enable training time augmentation.
                Defaults to False.

        Returns:
            Dict | List[Dict]: Data in the same format as the model input.
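
        Example:
            A sketch of a single-batch input for a constructed preprocessor;
            ``pts_i`` and ``sample_i`` are placeholder point tensors and data
            samples. A list of such dicts would instead trigger the
            test-time-augmentation branch::

                data = {
                    'inputs': {'points': [pts_1, pts_2]},
                    'data_samples': [sample_1, sample_2]
                }
                out = preprocessor.forward(data, training=False)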
        """
        if isinstance(data, list):
            num_augs = len(data)
            aug_batch_data = []
            for aug_id in range(num_augs):
                single_aug_batch_data = self.simple_process(
                    data[aug_id], training)
                aug_batch_data.append(single_aug_batch_data)
            return aug_batch_data

        else:
            return self.simple_process(data, training)

    def simple_process(self, data: dict, training: bool = False) -> dict:
        """Perform normalization、padding and bgr2rgb conversion for img data
        based on ``BaseDataPreprocessor``, and voxelize point cloud if `voxel`
        is set to be True.

        Args:
            data (dict): Data sampled from dataloader.
            training (bool): Whether to enable training time augmentation.
                Defaults to False.

        Returns:
            dict: Data in the same format as the model input.
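
        Example:
            A sketch of the returned structure; which keys appear depends on
            the input modalities and on ``self.voxel``::

                {
                    'inputs': {
                        'points': [...],  # list of (N_i, C) point tensors
                        'voxels': {...},  # present only when voxel=True
                        'imgs': ...       # batched image tensor
                    },
                    'data_samples': [...]
                }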
        """
        if 'img' in data['inputs']:
            batch_pad_shape = self._get_pad_shape(data)

        data = self.collate_data(data)
        inputs, data_samples = data['inputs'], data['data_samples']
        batch_inputs = dict()

        if 'points' in inputs:
            batch_inputs['points'] = inputs['points']

            if self.voxel:
                voxel_dict = self.voxelize(inputs['points'])
                batch_inputs['voxels'] = voxel_dict

        if 'imgs' in inputs:
            imgs = inputs['imgs']

            if data_samples is not None:
                # NOTE the batched image size information may be useful, e.g.
                # in DETR, this is needed for the construction of masks, which
                # is then used for the transformer_head.
                batch_input_shape = tuple(imgs[0].size()[-2:])
                for data_sample, pad_shape in zip(data_samples,
                                                  batch_pad_shape):
                    data_sample.set_metainfo({
                        'batch_input_shape': batch_input_shape,
                        'pad_shape': pad_shape
                    })

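                # The `hasattr` guards below presumably cover mmdet versions
                # that ship only one of `boxtype2tensor` / `boxlist2tensor`
                # (an assumption about the guard's intent).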
                if hasattr(self, 'boxtype2tensor') and self.boxtype2tensor:
                    from mmdet.models.utils.misc import \
                        samplelist_boxtype2tensor
                    samplelist_boxtype2tensor(data_samples)
                elif hasattr(self, 'boxlist2tensor') and self.boxlist2tensor:
                    from mmdet.models.utils.misc import \
                        samplelist_boxlist2tensor
                    samplelist_boxlist2tensor(data_samples)
                if self.pad_mask:
                    self.pad_gt_masks(data_samples)

                if self.pad_seg:
                    self.pad_gt_sem_seg(data_samples)

            if training and self.batch_augments is not None:
                for batch_aug in self.batch_augments:
                    imgs, data_samples = batch_aug(imgs, data_samples)
            batch_inputs['imgs'] = imgs

        return {'inputs': batch_inputs, 'data_samples': data_samples}

    def preprocess_img(self, _batch_img: torch.Tensor) -> torch.Tensor:
        # channel transform
        if self._channel_conversion:
            _batch_img = _batch_img[[2, 1, 0], ...]
        # Convert to float after channel conversion to ensure
        # efficiency
        _batch_img = _batch_img.float()
        # Normalization.
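        # self.mean / self.std are registered by the parent preprocessor as
        # (C, 1, 1) buffers, so they broadcast over the (C, H, W) image.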
        if self._enable_normalize:
            if self.mean.shape[0] == 3:
                assert _batch_img.dim() == 3 and _batch_img.shape[0] == 3, (
                    'If the mean has 3 values, the input tensor '
                    'should in shape of (3, H, W), but got the '
                    f'tensor with shape {_batch_img.shape}')
            _batch_img = (_batch_img - self.mean) / self.std
        return _batch_img

    def collate_data(self, data: dict) -> dict:
        """Copying data to the target device and Performs normalization、
        padding and bgr2rgb conversion and stack based on
        ``BaseDataPreprocessor``.

        Collates the data sampled from dataloader into a list of dict and
        a list of labels, and then copies tensors to the target device.

        Args:
            data (dict): Data sampled from dataloader.

        Returns:
            dict: Data in the same format as the model input.
        """
        data = self.cast_data(data)  # type: ignore

        if 'img' in data['inputs']:
            _batch_imgs = data['inputs']['img']
            # Process data with `pseudo_collate`.
            if is_list_of(_batch_imgs, torch.Tensor):
                batch_imgs = []
                img_dim = _batch_imgs[0].dim()
                for _batch_img in _batch_imgs:
                    if img_dim == 3:  # standard img
                        _batch_img = self.preprocess_img(_batch_img)
                    elif img_dim == 4:
                        _batch_img = [
                            self.preprocess_img(_img) for _img in _batch_img
                        ]

                        _batch_img = torch.stack(_batch_img, dim=0)

                    batch_imgs.append(_batch_img)

                # Pad and stack Tensor.
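                # img_dim == 3 means one (C, H, W) image per sample;
                # img_dim == 4 means multi-view (N_views, C, H, W) images
                # per sample, padded and stacked view-wise.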
                if img_dim == 3:
                    batch_imgs = stack_batch(batch_imgs, self.pad_size_divisor,
                                             self.pad_value)
                elif img_dim == 4:
                    batch_imgs = multiview_img_stack_batch(
                        batch_imgs, self.pad_size_divisor, self.pad_value)

            # Process data with `default_collate`.
            elif isinstance(_batch_imgs, torch.Tensor):
                assert _batch_imgs.dim() == 4, (
                    'The input of `ImgDataPreprocessor` should be a NCHW '
                    'tensor or a list of tensor, but got a tensor with '
                    f'shape: {_batch_imgs.shape}')
                if self._channel_conversion:
                    _batch_imgs = _batch_imgs[:, [2, 1, 0], ...]
                # Convert to float after channel conversion to ensure
                # efficiency
                _batch_imgs = _batch_imgs.float()
                if self._enable_normalize:
                    _batch_imgs = (_batch_imgs - self.mean) / self.std
                h, w = _batch_imgs.shape[2:]
                target_h = math.ceil(
                    h / self.pad_size_divisor) * self.pad_size_divisor
                target_w = math.ceil(
                    w / self.pad_size_divisor) * self.pad_size_divisor
                pad_h = target_h - h
                pad_w = target_w - w
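                # F.pad's last-two-dim tuple is (left, right, top, bottom);
                # pad only on the right and bottom so the original pixel
                # coordinates are preserved.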
                batch_imgs = F.pad(_batch_imgs, (0, pad_w, 0, pad_h),
                                   'constant', self.pad_value)
            else:
                raise TypeError(
                    'Output of `cast_data` should be a list of dict '
                    'or a tuple with inputs and data_samples, but got '
                    f'{type(data)}: {data}')

            data['inputs']['imgs'] = batch_imgs

        data.setdefault('data_samples', None)

        return data

    def _get_pad_shape(self, data: dict) -> List[tuple]:
        """Get the pad_shape of each image based on data and
        pad_size_divisor."""
        # rewrite `_get_pad_shape` for obaining image inputs.
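        # e.g. a 375 x 1242 image with pad_size_divisor=32 gets a pad shape
        # of ceil(375 / 32) * 32 x ceil(1242 / 32) * 32 = 384 x 1248.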
        _batch_inputs = data['inputs']['img']
        # Process data with `pseudo_collate`.
        if is_list_of(_batch_inputs, torch.Tensor):
            batch_pad_shape = []
            for ori_input in _batch_inputs:
                if ori_input.dim() == 4:
                    # a 4-dim input means multiview input; select one of
                    # the images to calculate the pad shape
                    ori_input = ori_input[0]
                pad_h = int(
                    np.ceil(ori_input.shape[1] /
                            self.pad_size_divisor)) * self.pad_size_divisor
                pad_w = int(
                    np.ceil(ori_input.shape[2] /
                            self.pad_size_divisor)) * self.pad_size_divisor
                batch_pad_shape.append((pad_h, pad_w))
        # Process data with `default_collate`.
        elif isinstance(_batch_inputs, torch.Tensor):
            assert _batch_inputs.dim() == 4, (
                'The input of `ImgDataPreprocessor` should be a NCHW tensor '
                'or a list of tensor, but got a tensor with shape: '
                f'{_batch_inputs.shape}')
            pad_h = int(
                np.ceil(_batch_inputs.shape[1] /
                        self.pad_size_divisor)) * self.pad_size_divisor
            pad_w = int(
                np.ceil(_batch_inputs.shape[2] /
                        self.pad_size_divisor)) * self.pad_size_divisor
            batch_pad_shape = [(pad_h, pad_w)] * _batch_inputs.shape[0]
        else:
            raise TypeError('Output of `cast_data` should be a list of dict '
                            'or a tuple with inputs and data_samples, but got'
                            f'{type(data)}: {data}')
        return batch_pad_shape

    @torch.no_grad()
    def voxelize(self, points: List[torch.Tensor]) -> Dict:
        """Apply voxelization to point cloud.

        Args:
            points (List[Tensor]): Point cloud in one data batch.

        Returns:
            dict[str, Tensor]: Voxelization information.

            - voxels (Tensor): Features of voxels, shape is MxNxC for hard
                voxelization, NxC for dynamic voxelization.
            - coors (Tensor): Coordinates of voxels, shape is Nx(1+NDim),
                where 1 represents the batch index.
            - num_points (Tensor, optional): Number of points in each voxel.
            - voxel_centers (Tensor, optional): Centers of voxels.
        """

        voxel_dict = dict()

        if self.voxel_type == 'hard':
            voxels, coors, num_points, voxel_centers = [], [], [], []
            for res in points:
                res_voxels, res_coors, res_num_points = self.voxel_layer(res)
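                # coors come back as (z, y, x); flip to (x, y, z), shift by
                # half a voxel to the voxel center, and map to metric
                # coordinates via voxel_size and the range minimum.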
                res_voxel_centers = (
                    res_coors[:, [2, 1, 0]] + 0.5) * res_voxels.new_tensor(
                        self.voxel_layer.voxel_size) + res_voxels.new_tensor(
                            self.voxel_layer.point_cloud_range[0:3])
                voxels.append(res_voxels)
                coors.append(res_coors)
                num_points.append(res_num_points)
                voxel_centers.append(res_voxel_centers)

            voxels = torch.cat(voxels, dim=0)
            num_points = torch.cat(num_points, dim=0)
            voxel_centers = torch.cat(voxel_centers, dim=0)
            coors_batch = []
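            # prepend the sample index as an extra leading column so voxels
            # from different samples stay distinguishable after concatenation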
            for i, coor in enumerate(coors):
                coor_pad = F.pad(coor, (1, 0), mode='constant', value=i)
                coors_batch.append(coor_pad)
            coors_batch = torch.cat(coors_batch, dim=0)
            voxel_dict['num_points'] = num_points
            voxel_dict['voxel_centers'] = voxel_centers
        elif self.voxel_type == 'dynamic':
            coors = []
            # dynamic voxelization only provides a coors mapping;
            # the raw points are used directly as voxel features
            for res in points:
                res_coors = self.voxel_layer(res)
                coors.append(res_coors)
            voxels = torch.cat(points, dim=0)
            coors_batch = []
            for i, coor in enumerate(coors):
                coor_pad = F.pad(coor, (1, 0), mode='constant', value=i)
                coors_batch.append(coor_pad)
            coors_batch = torch.cat(coors_batch, dim=0)
        else:
            raise ValueError(f'Invalid voxelization type {self.voxel_type}')

        voxel_dict['voxels'] = voxels
        voxel_dict['coors'] = coors_batch

        return voxel_dict