# Copyright (c) OpenMMLab. All rights reserved.
import math
from numbers import Number
from typing import Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
from mmcv.ops import Voxelization
from mmengine.model import stack_batch
from mmengine.utils import is_list_of
from torch.nn import functional as F

from mmdet3d.registry import MODELS
from mmdet3d.utils import OptConfigType
from mmdet.models import DetDataPreprocessor
from mmdet.models.utils.misc import samplelist_boxtype2tensor

from .utils import multiview_img_stack_batch


@MODELS.register_module()
class Det3DDataPreprocessor(DetDataPreprocessor):
    """Points / Image pre-processor for point clouds / vision-only / multi-
    modality 3D detection tasks.

    It provides the data pre-processing as follows:

    - Collate and move image and point cloud data to the target device.

    - 1) For image data:

      - Pad images in inputs to the maximum size of current batch with defined
        ``pad_value``. The padding size can be divisible by a defined
        ``pad_size_divisor``.
      - Stack images in inputs to batch_imgs.
      - Convert images in inputs from BGR to RGB if the shape of input is
        (3, H, W).
      - Normalize images in inputs with defined std and mean.
      - Do batch augmentations during training.

    - 2) For point cloud data:

      - If no voxelization, directly return list of point cloud data.
      - If voxelization is applied, voxelize point cloud according to
        ``voxel_type`` and obtain ``voxels``.

    Args:
        voxel (bool): Whether to apply voxelization to point cloud.
            Defaults to False.
        voxel_type (str): Voxelization type. Two voxelization types are
            provided: 'hard' and 'dynamic', respectively for hard
            voxelization and dynamic voxelization. Defaults to 'hard'.
        voxel_layer (:obj:`ConfigDict`, optional): Voxelization layer
            config. Defaults to None.
        mean (Sequence[Number], optional): The pixel mean of R, G, B channels.
            Defaults to None.
        std (Sequence[Number], optional): The pixel standard deviation of
            R, G, B channels. Defaults to None.
        pad_size_divisor (int): The size of padded image should be
            divisible by ``pad_size_divisor``. Defaults to 1.
        pad_value (Number): The padded pixel value. Defaults to 0.
        pad_mask (bool): Whether to pad instance masks. Defaults to False.
        mask_pad_value (int): The padded pixel value for instance masks.
            Defaults to 0.
        pad_seg (bool): Whether to pad semantic segmentation maps.
            Defaults to False.
        seg_pad_value (int): The padded pixel value for semantic
            segmentation maps. Defaults to 255.
        bgr_to_rgb (bool): Whether to convert image from BGR to RGB.
            Defaults to False.
        rgb_to_bgr (bool): Whether to convert image from RGB to BGR.
            Defaults to False.
        boxtype2tensor (bool): Whether to convert the ``BaseBoxes`` type of
            bboxes data to ``Tensor`` type. Defaults to True.
        batch_augments (List[dict], optional): Batch-level augmentations.
            Defaults to None.
    """

    def __init__(self,
                 voxel: bool = False,
                 voxel_type: str = 'hard',
                 voxel_layer: OptConfigType = None,
                 mean: Sequence[Number] = None,
                 std: Sequence[Number] = None,
                 pad_size_divisor: int = 1,
                 pad_value: Union[float, int] = 0,
                 pad_mask: bool = False,
                 mask_pad_value: int = 0,
                 pad_seg: bool = False,
                 seg_pad_value: int = 255,
                 bgr_to_rgb: bool = False,
                 rgb_to_bgr: bool = False,
                 boxtype2tensor: bool = True,
                 batch_augments: Optional[List[dict]] = None):
        super().__init__(
            mean=mean,
            std=std,
            pad_size_divisor=pad_size_divisor,
            pad_value=pad_value,
            pad_mask=pad_mask,
            mask_pad_value=mask_pad_value,
            pad_seg=pad_seg,
            seg_pad_value=seg_pad_value,
            bgr_to_rgb=bgr_to_rgb,
            rgb_to_bgr=rgb_to_bgr,
            boxtype2tensor=boxtype2tensor,
            batch_augments=batch_augments)
        self.voxel = voxel
        self.voxel_type = voxel_type
        if voxel:
            self.voxel_layer = Voxelization(**voxel_layer)

    def forward(self,
                data: Union[dict, List[dict]],
                training: bool = False) -> Union[dict, List[dict]]:
        """Perform normalization、padding and bgr2rgb conversion based on
        ``BaseDataPreprocessor``.

        Args:
            data (dict | List[dict]): Data from dataloader. The dict contains
                the whole batch data; when it is a list[dict], the list
                indicates test-time augmentation.
            training (bool): Whether to enable training time augmentation.
                Defaults to False.

        Returns:
            Dict | List[Dict]: Data in the same format as the model input.
        """
        if isinstance(data, list):
            num_augs = len(data)
            aug_batch_data = []
            for aug_id in range(num_augs):
                single_aug_batch_data = self.simple_process(
                    data[aug_id], training)
                aug_batch_data.append(single_aug_batch_data)
            return aug_batch_data
        else:
            return self.simple_process(data, training)

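    # Note on test-time augmentation (illustrative sketch): when the
    # dataloader yields a list of dicts, each dict is one augmented view of
    # the same batch and is processed independently. For example (the
    # ``aug_view_*`` names are hypothetical batch dicts):
    #   outs = preprocessor([aug_view_1, aug_view_2], training=False)
    #   assert len(outs) == 2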
    def simple_process(self, data: dict, training: bool = False) -> dict:
        """Perform normalization、padding and bgr2rgb conversion for img data
        based on ``BaseDataPreprocessor``, and voxelize point cloud if `voxel`
        is set to be True.

        Args:
            data (dict): Data sampled from dataloader.
            training (bool): Whether to enable training time augmentation.
                Defaults to False.

        Returns:
            dict: Data in the same format as the model input.
        """
        if 'img' in data['inputs']:
            batch_pad_shape = self._get_pad_shape(data)

        data = self.collate_data(data)
        inputs, data_samples = data['inputs'], data['data_samples']
        batch_inputs = dict()

        if 'points' in inputs:
            batch_inputs['points'] = inputs['points']

            if self.voxel:
                voxel_dict = self.voxelize(inputs['points'])
                batch_inputs['voxels'] = voxel_dict

        if 'imgs' in inputs:
            imgs = inputs['imgs']

            if data_samples is not None:
                # NOTE the batched image size information may be useful, e.g.
                # in DETR, this is needed for the construction of masks, which
                # is then used for the transformer_head.
                batch_input_shape = tuple(imgs[0].size()[-2:])
                for data_sample, pad_shape in zip(data_samples,
                                                  batch_pad_shape):
                    data_sample.set_metainfo({
                        'batch_input_shape': batch_input_shape,
                        'pad_shape': pad_shape
                    })

                if self.boxtype2tensor:
                    samplelist_boxtype2tensor(data_samples)

                if self.pad_mask:
                    self.pad_gt_masks(data_samples)

                if self.pad_seg:
                    self.pad_gt_sem_seg(data_samples)

            if training and self.batch_augments is not None:
                for batch_aug in self.batch_augments:
                    imgs, data_samples = batch_aug(imgs, data_samples)
            batch_inputs['imgs'] = imgs

        return {'inputs': batch_inputs, 'data_samples': data_samples}
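
    # Returned structure sketch (illustrative): with voxelization and images
    # enabled, the returned dict looks like
    #   {
    #       'inputs': {
    #           'points': List[Tensor],  # per-sample point clouds
    #           'voxels': dict with 'voxels' and 'coors' (plus 'num_points'
    #                     and 'voxel_centers' for hard voxelization),
    #           'imgs': a batched image tensor
    #       },
    #       'data_samples': the (possibly augmented) data samples
    #   }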

    def preprocess_img(self, _batch_img):
        # channel transform
        if self._channel_conversion:
            _batch_img = _batch_img[[2, 1, 0], ...]
        # Convert to float after channel conversion to ensure
        # efficiency
        _batch_img = _batch_img.float()
        # Normalization.
        if self._enable_normalize:
            if self.mean.shape[0] == 3:
                assert _batch_img.dim() == 3 and _batch_img.shape[0] == 3, (
                    'If the mean has 3 values, the input tensor '
                    'should in shape of (3, H, W), but got the '
                    f'tensor with shape {_batch_img.shape}')
            _batch_img = (_batch_img - self.mean) / self.std
        return _batch_img
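
    # Note on the normalization above (descriptive): the parent
    # ``ImgDataPreprocessor`` registers ``mean`` and ``std`` as (C, 1, 1)
    # buffers, so ``(_batch_img - self.mean) / self.std`` broadcasts over
    # H and W.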

    def collate_data(self, data: dict) -> dict:
        """Copying data to the target device and Performs normalization、
        padding and bgr2rgb conversion and stack based on
        ``BaseDataPreprocessor``.

        Collates the data sampled from dataloader into a list of dict and
        list of labels, and then copies tensor to the target device.

        Args:
            data (dict): Data sampled from dataloader.

        Returns:
            dict: Data in the same format as the model input.
        """
        data = self.cast_data(data)  # type: ignore

        if 'img' in data['inputs']:
            _batch_imgs = data['inputs']['img']
            # Process data with `pseudo_collate`.
            if is_list_of(_batch_imgs, torch.Tensor):
                batch_imgs = []
                img_dim = _batch_imgs[0].dim()
                for _batch_img in _batch_imgs:
                    if img_dim == 3:  # standard img
                        _batch_img = self.preprocess_img(_batch_img)
                    elif img_dim == 4:
                        _batch_img = [
                            self.preprocess_img(_img) for _img in _batch_img
                        ]

                        _batch_img = torch.stack(_batch_img, dim=0)

                    batch_imgs.append(_batch_img)

                # Pad and stack Tensor.
                if img_dim == 3:
                    batch_imgs = stack_batch(batch_imgs, self.pad_size_divisor,
                                             self.pad_value)
                elif img_dim == 4:
                    batch_imgs = multiview_img_stack_batch(
                        batch_imgs, self.pad_size_divisor, self.pad_value)

            # Process data with `default_collate`.
            elif isinstance(_batch_imgs, torch.Tensor):
                assert _batch_imgs.dim() == 4, (
                    'The input of `ImgDataPreprocessor` should be a NCHW '
                    'tensor or a list of tensor, but got a tensor with '
                    f'shape: {_batch_imgs.shape}')
                if self._channel_conversion:
                    _batch_imgs = _batch_imgs[:, [2, 1, 0], ...]
                # Convert to float after channel conversion to ensure
                # efficiency
                _batch_imgs = _batch_imgs.float()
                if self._enable_normalize:
                    _batch_imgs = (_batch_imgs - self.mean) / self.std
                h, w = _batch_imgs.shape[2:]
                target_h = math.ceil(
                    h / self.pad_size_divisor) * self.pad_size_divisor
                target_w = math.ceil(
                    w / self.pad_size_divisor) * self.pad_size_divisor
                pad_h = target_h - h
                pad_w = target_w - w
                batch_imgs = F.pad(_batch_imgs, (0, pad_w, 0, pad_h),
                                   'constant', self.pad_value)
            else:
                raise TypeError(
                    'Output of `cast_data` should be a list of dict '
                    'or a tuple with inputs and data_samples, but got '
                    f'{type(data)}: {data}')

            data['inputs']['imgs'] = batch_imgs

        data.setdefault('data_samples', None)

        return data
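
    # Input format note (descriptive): with ``pseudo_collate``,
    # ``data['inputs']['img']`` is a list of per-sample CHW tensors (or 4D
    # multi-view tensors); with ``default_collate``, it is a single stacked
    # NCHW tensor. Both branches above normalize, pad and batch accordingly.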

    def _get_pad_shape(self, data: dict) -> List[tuple]:
        """Get the pad_shape of each image based on data and
        pad_size_divisor."""
        # rewrite `_get_pad_shape` for obtaining image inputs.
        _batch_inputs = data['inputs']['img']
        # Process data with `pseudo_collate`.
        if is_list_of(_batch_inputs, torch.Tensor):
            batch_pad_shape = []
            for ori_input in _batch_inputs:
                if ori_input.dim() == 4:
                    # a 4D tensor means multiview input; select one of the
                    # images to calculate the pad shape
                    ori_input = ori_input[0]
                pad_h = int(
                    np.ceil(ori_input.shape[1] /
                            self.pad_size_divisor)) * self.pad_size_divisor
                pad_w = int(
                    np.ceil(ori_input.shape[2] /
                            self.pad_size_divisor)) * self.pad_size_divisor
                batch_pad_shape.append((pad_h, pad_w))
        # Process data with `default_collate`.
        elif isinstance(_batch_inputs, torch.Tensor):
            assert _batch_inputs.dim() == 4, (
                'The input of `ImgDataPreprocessor` should be a NCHW tensor '
                'or a list of tensor, but got a tensor with shape: '
                f'{_batch_inputs.shape}')
            pad_h = int(
                np.ceil(_batch_inputs.shape[1] /
                        self.pad_size_divisor)) * self.pad_size_divisor
            pad_w = int(
                np.ceil(_batch_inputs.shape[2] /
                        self.pad_size_divisor)) * self.pad_size_divisor
            batch_pad_shape = [(pad_h, pad_w)] * _batch_inputs.shape[0]
        else:
            raise TypeError('Output of `cast_data` should be a list of dict '
                            'or a tuple with inputs and data_samples, but got '
                            f'{type(data)}: {data}')
        return batch_pad_shape
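
    # Worked example for the divisor padding above (illustrative numbers):
    # with pad_size_divisor=32, an image of shape (3, 375, 1242) is padded to
    #   pad_h = ceil(375 / 32) * 32 = 384
    #   pad_w = ceil(1242 / 32) * 32 = 1248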

    @torch.no_grad()
    def voxelize(self, points: List[torch.Tensor]) -> Dict:
        """Apply voxelization to point cloud.

        Args:
            points (List[Tensor]): Point cloud in one data batch.

        Returns:
            dict[str, Tensor]: Voxelization information.

            - voxels (Tensor): Features of voxels, shape is MxNxC for hard
                voxelization, NxC for dynamic voxelization.
            - coors (Tensor): Coordinates of voxels, shape is Nx(1+NDim),
                where 1 represents the batch index.
            - num_points (Tensor, optional): Number of points in each voxel.
            - voxel_centers (Tensor, optional): Centers of voxels.
        """

        voxel_dict = dict()

        if self.voxel_type == 'hard':
            voxels, coors, num_points, voxel_centers = [], [], [], []
            for res in points:
                res_voxels, res_coors, res_num_points = self.voxel_layer(res)
                res_voxel_centers = (
                    res_coors[:, [2, 1, 0]] + 0.5) * res_voxels.new_tensor(
                        self.voxel_layer.voxel_size) + res_voxels.new_tensor(
                            self.voxel_layer.point_cloud_range[0:3])
                voxels.append(res_voxels)
                coors.append(res_coors)
                num_points.append(res_num_points)
                voxel_centers.append(res_voxel_centers)

            voxels = torch.cat(voxels, dim=0)
            num_points = torch.cat(num_points, dim=0)
            voxel_centers = torch.cat(voxel_centers, dim=0)
            coors_batch = []
            for i, coor in enumerate(coors):
                coor_pad = F.pad(coor, (1, 0), mode='constant', value=i)
                coors_batch.append(coor_pad)
            coors_batch = torch.cat(coors_batch, dim=0)
            voxel_dict['num_points'] = num_points
            voxel_dict['voxel_centers'] = voxel_centers
        elif self.voxel_type == 'dynamic':
            coors = []
            # dynamic voxelization only provides a coors mapping
            for res in points:
                res_coors = self.voxel_layer(res)
                coors.append(res_coors)
            voxels = torch.cat(points, dim=0)
            coors_batch = []
            for i, coor in enumerate(coors):
                coor_pad = F.pad(coor, (1, 0), mode='constant', value=i)
                coors_batch.append(coor_pad)
            coors_batch = torch.cat(coors_batch, dim=0)
        else:
            raise ValueError(f'Invalid voxelization type {self.voxel_type}')

        voxel_dict['voxels'] = voxels
        voxel_dict['coors'] = coors_batch

        return voxel_dict
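

# Minimal voxelization sketch (illustrative, not part of the original module).
# The ``voxel_layer`` keys follow ``mmcv.ops.Voxelization`` and the values are
# assumptions; running this requires mmcv built with the voxelization op.
if __name__ == '__main__':
    preprocessor = Det3DDataPreprocessor(
        voxel=True,
        voxel_type='hard',
        voxel_layer=dict(
            voxel_size=[0.05, 0.05, 0.1],
            point_cloud_range=[0, -40, -3, 70.4, 40, 1],
            max_num_points=32,
            max_voxels=(16000, 40000)))
    # Two fake point clouds with (x, y, z, intensity) columns; ``torch.rand``
    # keeps every point inside the configured point_cloud_range.
    points = [torch.rand(1000, 4), torch.rand(800, 4)]
    voxel_dict = preprocessor.voxelize(points)
    # voxels: (M, 32, 4); coors: (M, 4) with the batch index prepended.
    print(voxel_dict['voxels'].shape, voxel_dict['coors'].shape)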