# Copyright (c) OpenMMLab. All rights reserved.
import math
from numbers import Number
from typing import Dict, List, Optional, Sequence, Union

import numpy as np
import torch
from mmcv.ops import Voxelization
from mmdet.models import DetDataPreprocessor
from mmengine.model import stack_batch
from mmengine.utils import is_list_of
from torch.nn import functional as F

from mmdet3d.registry import MODELS
from mmdet3d.utils import OptConfigType
from .utils import multiview_img_stack_batch


@MODELS.register_module()
class Det3DDataPreprocessor(DetDataPreprocessor):
    """Points / Image pre-processor for point clouds / vision-only / multi-
    modality 3D detection tasks.

    It provides the data pre-processing as follows:

    - Collate and move image and point cloud data to the target device.

    - 1) For image data:

      - Pad images in inputs to the maximum size of current batch with
        defined ``pad_value``. The padding size is divisible by a defined
        ``pad_size_divisor``.
      - Stack images in inputs to batch_imgs.
      - Convert images in inputs from bgr to rgb if the shape of input is
        (3, H, W).
      - Normalize images in inputs with defined std and mean.
      - Do batch augmentations during training.

    - 2) For point cloud data:

      - If no voxelization, directly return list of point cloud data.
      - If voxelization is applied, voxelize point cloud according to
        ``voxel_type`` and obtain ``voxels``.

    Args:
        voxel (bool): Whether to apply voxelization to point cloud.
            Defaults to False.
        voxel_type (str): Voxelization type. Two voxelization types are
            provided: 'hard' and 'dynamic', respectively for hard
            voxelization and dynamic voxelization. Defaults to 'hard'.
        voxel_layer (dict or :obj:`ConfigDict`, optional): Voxelization layer
            config. Defaults to None.
        mean (Sequence[Number], optional): The pixel mean of R, G, B channels.
            Defaults to None.
        std (Sequence[Number], optional): The pixel standard deviation of
            R, G, B channels. Defaults to None.
        pad_size_divisor (int): The size of padded image should be
            divisible by ``pad_size_divisor``. Defaults to 1.
        pad_value (Number): The padded pixel value. Defaults to 0.
        pad_mask (bool): Whether to pad instance masks. Defaults to False.
        mask_pad_value (int): The padded pixel value for instance masks.
            Defaults to 0.
        pad_seg (bool): Whether to pad semantic segmentation maps.
            Defaults to False.
        seg_pad_value (int): The padded pixel value for semantic
            segmentation maps. Defaults to 255.
        bgr_to_rgb (bool): Whether to convert image from BGR to RGB.
            Defaults to False.
        rgb_to_bgr (bool): Whether to convert image from RGB to BGR.
            Defaults to False.
        boxtype2tensor (bool): Whether to convert the ``BaseBoxes`` type of
            bboxes data to Tensor type. Defaults to True.
        batch_augments (List[dict], optional): Batch-level augmentations.
            Defaults to None.
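
    Example:
        A minimal construction sketch; the ``voxel_layer`` values below are
        illustrative assumptions, not defaults shipped with this class::

            preprocessor = Det3DDataPreprocessor(
                voxel=True,
                voxel_type='hard',
                voxel_layer=dict(
                    max_num_points=32,
                    point_cloud_range=[0, -40, -3, 70.4, 40, 1],
                    voxel_size=[0.05, 0.05, 0.1],
                    max_voxels=(16000, 40000)),
                mean=[123.675, 116.28, 103.53],
                std=[58.395, 57.12, 57.375],
                bgr_to_rgb=True)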
    """

    def __init__(self,
                 voxel: bool = False,
                 voxel_type: str = 'hard',
                 voxel_layer: OptConfigType = None,
                 mean: Optional[Sequence[Number]] = None,
                 std: Optional[Sequence[Number]] = None,
                 pad_size_divisor: int = 1,
                 pad_value: Union[float, int] = 0,
                 pad_mask: bool = False,
                 mask_pad_value: int = 0,
                 pad_seg: bool = False,
                 seg_pad_value: int = 255,
                 bgr_to_rgb: bool = False,
                 rgb_to_bgr: bool = False,
                 boxtype2tensor: bool = True,
                 batch_augments: Optional[List[dict]] = None) -> None:
        super(Det3DDataPreprocessor, self).__init__(
            mean=mean,
            std=std,
            pad_size_divisor=pad_size_divisor,
            pad_value=pad_value,
            pad_mask=pad_mask,
            mask_pad_value=mask_pad_value,
            pad_seg=pad_seg,
            seg_pad_value=seg_pad_value,
            bgr_to_rgb=bgr_to_rgb,
            rgb_to_bgr=rgb_to_bgr,
            boxtype2tensor=boxtype2tensor,
            batch_augments=batch_augments)
        self.voxel = voxel
        self.voxel_type = voxel_type
        if voxel:
            self.voxel_layer = Voxelization(**voxel_layer)

    def forward(self,
                data: Union[dict, List[dict]],
                training: bool = False) -> Union[dict, List[dict]]:
        """Perform normalization, padding and bgr2rgb conversion based on
        ``BaseDataPreprocessor``.

        Args:
            data (dict or List[dict]): Data from dataloader.
                The dict contains the whole batch data; when it is a
                list[dict], the list indicates test-time augmentation.
            training (bool): Whether to enable training time augmentation.
                Defaults to False.

        Returns:
            dict or List[dict]: Data in the same format as the model input.
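
        Example:
            A sketch of the expected input layout; the field names follow
            mmengine's ``pseudo_collate`` convention and the tensors are
            illustrative assumptions::

                data = {
                    'inputs': {
                        'points': [torch.rand(16384, 4),
                                   torch.rand(20000, 4)],
                        'img': [torch.randint(0, 256, (3, 370, 1224),
                                              dtype=torch.uint8)] * 2
                    },
                    'data_samples': [...]
                }
                results = preprocessor(data, training=False)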
        """
        if isinstance(data, list):
            num_augs = len(data)
            aug_batch_data = []
            for aug_id in range(num_augs):
                single_aug_batch_data = self.simple_process(
                    data[aug_id], training)
                aug_batch_data.append(single_aug_batch_data)
            return aug_batch_data
        else:
            return self.simple_process(data, training)

    def simple_process(self, data: dict, training: bool = False) -> dict:
        """Perform normalization, padding and bgr2rgb conversion for img data
        based on ``BaseDataPreprocessor``, and voxelize point cloud if
        ``voxel`` is set to True.

        Args:
            data (dict): Data sampled from dataloader.
            training (bool): Whether to enable training time augmentation.
                Defaults to False.

        Returns:
            dict: Data in the same format as the model input.
        """
        if 'img' in data['inputs']:
            batch_pad_shape = self._get_pad_shape(data)

        data = self.collate_data(data)
        inputs, data_samples = data['inputs'], data['data_samples']
        batch_inputs = dict()

        if 'points' in inputs:
            batch_inputs['points'] = inputs['points']

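            # When voxelization is enabled, convert the raw point list into
            # a voxel dict consumed by downstream voxel feature encoders.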
            if self.voxel:
                voxel_dict = self.voxelize(inputs['points'])
                batch_inputs['voxels'] = voxel_dict

        if 'imgs' in inputs:
            imgs = inputs['imgs']

            if data_samples is not None:
                # NOTE the batched image size information may be useful, e.g.
                # in DETR, this is needed for the construction of masks, which
                # is then used for the transformer_head.
                batch_input_shape = tuple(imgs[0].size()[-2:])
                for data_sample, pad_shape in zip(data_samples,
                                                  batch_pad_shape):
                    data_sample.set_metainfo({
                        'batch_input_shape': batch_input_shape,
                        'pad_shape': pad_shape
                    })

                if hasattr(self, 'boxtype2tensor') and self.boxtype2tensor:
                    from mmdet.models.utils.misc import \
                        samplelist_boxtype2tensor
                    samplelist_boxtype2tensor(data_samples)
                elif hasattr(self, 'boxlist2tensor') and self.boxlist2tensor:
                    from mmdet.models.utils.misc import \
                        samplelist_boxlist2tensor
                    samplelist_boxlist2tensor(data_samples)
                if self.pad_mask:
                    self.pad_gt_masks(data_samples)

                if self.pad_seg:
                    self.pad_gt_sem_seg(data_samples)

            if training and self.batch_augments is not None:
                for batch_aug in self.batch_augments:
                    imgs, data_samples = batch_aug(imgs, data_samples)
            batch_inputs['imgs'] = imgs

        return {'inputs': batch_inputs, 'data_samples': data_samples}

    def preprocess_img(self, _batch_img: torch.Tensor) -> torch.Tensor:
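        """Normalize a single image tensor: channel conversion, cast to
        float and mean/std normalization."""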
        # channel transform
        if self._channel_conversion:
            _batch_img = _batch_img[[2, 1, 0], ...]
        # Convert to float after channel conversion to ensure
        # efficiency
        _batch_img = _batch_img.float()
        # Normalization.
        if self._enable_normalize:
            if self.mean.shape[0] == 3:
                assert _batch_img.dim() == 3 and _batch_img.shape[0] == 3, (
                    'If the mean has 3 values, the input tensor '
                    'should be in shape of (3, H, W), but got the '
                    f'tensor with shape {_batch_img.shape}')
            _batch_img = (_batch_img - self.mean) / self.std
        return _batch_img

    def collate_data(self, data: dict) -> dict:
        """Copy data to the target device and perform normalization, padding,
        bgr2rgb conversion and stacking based on ``BaseDataPreprocessor``.

        Collates the data sampled from dataloader into a list of dict and
        list of labels, and then copies tensors to the target device.

        Args:
            data (dict): Data sampled from dataloader.

        Returns:
            dict: Data in the same format as the model input.
        """
        data = self.cast_data(data)  # type: ignore

        if 'img' in data['inputs']:
            _batch_imgs = data['inputs']['img']
            # Process data with `pseudo_collate`.
            if is_list_of(_batch_imgs, torch.Tensor):
                batch_imgs = []
                img_dim = _batch_imgs[0].dim()
                for _batch_img in _batch_imgs:
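                    # A 3-dim tensor is a single image (C, H, W); a 4-dim
                    # tensor holds multi-view images (N, C, H, W).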
                    if img_dim == 3:  # standard img
                        _batch_img = self.preprocess_img(_batch_img)
                    elif img_dim == 4:
                        _batch_img = [
                            self.preprocess_img(_img) for _img in _batch_img
                        ]

                        _batch_img = torch.stack(_batch_img, dim=0)

                    batch_imgs.append(_batch_img)

                # Pad and stack Tensor.
                if img_dim == 3:
                    batch_imgs = stack_batch(batch_imgs, self.pad_size_divisor,
                                             self.pad_value)
                elif img_dim == 4:
                    batch_imgs = multiview_img_stack_batch(
                        batch_imgs, self.pad_size_divisor, self.pad_value)

            # Process data with `default_collate`.
            elif isinstance(_batch_imgs, torch.Tensor):
                assert _batch_imgs.dim() == 4, (
                    'The input of `ImgDataPreprocessor` should be a NCHW '
                    'tensor or a list of tensor, but got a tensor with '
                    f'shape: {_batch_imgs.shape}')
                if self._channel_conversion:
                    _batch_imgs = _batch_imgs[:, [2, 1, 0], ...]
                # Convert to float after channel conversion to ensure
                # efficiency
                _batch_imgs = _batch_imgs.float()
                if self._enable_normalize:
                    _batch_imgs = (_batch_imgs - self.mean) / self.std
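                # Pad H and W up to the next multiple of
                # ``pad_size_divisor``; padding is added at bottom/right.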
                h, w = _batch_imgs.shape[2:]
                target_h = math.ceil(
                    h / self.pad_size_divisor) * self.pad_size_divisor
                target_w = math.ceil(
                    w / self.pad_size_divisor) * self.pad_size_divisor
                pad_h = target_h - h
                pad_w = target_w - w
                batch_imgs = F.pad(_batch_imgs, (0, pad_w, 0, pad_h),
                                   'constant', self.pad_value)
            else:
                raise TypeError(
                    'Output of `cast_data` should be a list of dict '
                    'or a tuple with inputs and data_samples, but got '
                    f'{type(data)}: {data}')

            data['inputs']['imgs'] = batch_imgs

        data.setdefault('data_samples', None)

        return data

    def _get_pad_shape(self, data: dict) -> List[tuple]:
        """Get the pad_shape of each image based on data and
        pad_size_divisor."""
        # rewrite `_get_pad_shape` for obtaining image inputs.
        _batch_inputs = data['inputs']['img']
        # Process data with `pseudo_collate`.
        if is_list_of(_batch_inputs, torch.Tensor):
            batch_pad_shape = []
            for ori_input in _batch_inputs:
                if ori_input.dim() == 4:
                    # means multi-view input, select one of the
                    # images to calculate the pad shape
                    ori_input = ori_input[0]
                pad_h = int(
                    np.ceil(ori_input.shape[1] /
                            self.pad_size_divisor)) * self.pad_size_divisor
                pad_w = int(
                    np.ceil(ori_input.shape[2] /
                            self.pad_size_divisor)) * self.pad_size_divisor
                batch_pad_shape.append((pad_h, pad_w))
        # Process data with `default_collate`.
        elif isinstance(_batch_inputs, torch.Tensor):
            assert _batch_inputs.dim() == 4, (
                'The input of `ImgDataPreprocessor` should be a NCHW tensor '
                'or a list of tensor, but got a tensor with shape: '
                f'{_batch_inputs.shape}')
            pad_h = int(
                np.ceil(_batch_inputs.shape[1] /
                        self.pad_size_divisor)) * self.pad_size_divisor
            pad_w = int(
                np.ceil(_batch_inputs.shape[2] /
                        self.pad_size_divisor)) * self.pad_size_divisor
            batch_pad_shape = [(pad_h, pad_w)] * _batch_inputs.shape[0]
        else:
            raise TypeError('Output of `cast_data` should be a list of dict '
                            'or a tuple with inputs and data_samples, but got '
                            f'{type(data)}: {data}')
        return batch_pad_shape

    @torch.no_grad()
    def voxelize(self, points: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Apply voxelization to point cloud.

        Args:
            points (List[Tensor]): Point cloud in one data batch.

        Returns:
            Dict[str, Tensor]: Voxelization information.

            - voxels (Tensor): Features of voxels, shape is MxNxC for hard
              voxelization, NxC for dynamic voxelization.
            - coors (Tensor): Coordinates of voxels, shape is Nx(1+NDim),
              where 1 represents the batch index.
            - num_points (Tensor, optional): Number of points in each voxel.
            - voxel_centers (Tensor, optional): Centers of voxels.
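
        Example:
            A shape sketch for hard voxelization; the point counts and the
            4-dim point features are illustrative assumptions::

                points = [torch.rand(16384, 4), torch.rand(20000, 4)]
                voxel_dict = self.voxelize(points)
                # voxel_dict['voxels']: (M, max_points, 4)
                # voxel_dict['coors']: (M, 4), columns (batch_idx, z, y, x)
                # voxel_dict['num_points']: (M, )
                # voxel_dict['voxel_centers']: (M, 3)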
        """

        voxel_dict = dict()

        if self.voxel_type == 'hard':
            voxels, coors, num_points, voxel_centers = [], [], [], []
            for i, res in enumerate(points):
                res_voxels, res_coors, res_num_points = self.voxel_layer(res)
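                # ``res_coors`` is in (z, y, x) order: reorder to (x, y, z),
                # shift by half a voxel, scale by the voxel size, then
                # offset by the minimum of the point cloud range.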
                res_voxel_centers = (
                    res_coors[:, [2, 1, 0]] + 0.5) * res_voxels.new_tensor(
                        self.voxel_layer.voxel_size) + res_voxels.new_tensor(
                            self.voxel_layer.point_cloud_range[0:3])
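                # Prepend the sample index so voxels from different samples
                # can be concatenated into a single batch tensor.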
                res_coors = F.pad(res_coors, (1, 0), mode='constant', value=i)
                voxels.append(res_voxels)
                coors.append(res_coors)
                num_points.append(res_num_points)
                voxel_centers.append(res_voxel_centers)

            voxels = torch.cat(voxels, dim=0)
            coors = torch.cat(coors, dim=0)
            num_points = torch.cat(num_points, dim=0)
            voxel_centers = torch.cat(voxel_centers, dim=0)

            voxel_dict['num_points'] = num_points
            voxel_dict['voxel_centers'] = voxel_centers
        elif self.voxel_type == 'dynamic':
            coors = []
            # dynamic voxelization only provides a coors mapping
            for i, res in enumerate(points):
                res_coors = self.voxel_layer(res)
                res_coors = F.pad(res_coors, (1, 0), mode='constant', value=i)
                coors.append(res_coors)
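            # For dynamic voxelization, keep the raw per-point features;
            # ``coors`` maps every point to its voxel.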
            voxels = torch.cat(points, dim=0)
            coors = torch.cat(coors, dim=0)
        else:
            raise ValueError(f'Invalid voxelization type {self.voxel_type}')

        voxel_dict['voxels'] = voxels
        voxel_dict['coors'] = coors

        return voxel_dict