# Copyright (c) OpenMMLab. All rights reserved.
import math
from numbers import Number
from typing import Dict, List, Optional, Sequence, Union

import numpy as np
import torch
from mmdet.models import DetDataPreprocessor
from mmengine.model import stack_batch
from mmengine.utils import is_list_of
from torch.nn import functional as F

from mmdet3d.registry import MODELS
from mmdet3d.structures.det3d_data_sample import SampleList
from mmdet3d.utils import OptConfigType
from .utils import multiview_img_stack_batch
from .voxelize import VoxelizationByGridShape, dynamic_scatter_3d


@MODELS.register_module()
class Det3DDataPreprocessor(DetDataPreprocessor):
    """Points / Image pre-processor for point clouds / vision-only / multi-
    modality 3D detection tasks.

    It provides the data pre-processing as follows:

    - Collate and move image and point cloud data to the target device.

    - 1) For image data:

      - Pad images in inputs to the maximum size of the current batch with
        defined ``pad_value``. The padding size is divisible by the defined
        ``pad_size_divisor``.
      - Stack images in inputs to batch_imgs.
      - Convert images in inputs from bgr to rgb if the shape of input is
        (3, H, W).
      - Normalize images in inputs with defined std and mean.
      - Do batch augmentations during training.

    - 2) For point cloud data:

      - If no voxelization, directly return list of point cloud data.
      - If voxelization is applied, voxelize point cloud according to
        ``voxel_type`` and obtain ``voxels``.

    Args:
        voxel (bool): Whether to apply voxelization to point cloud.
            Defaults to False.
        voxel_type (str): Voxelization type. Two voxelization types are
            provided: 'hard' and 'dynamic', respectively for hard
            voxelization and dynamic voxelization. Defaults to 'hard'.
        voxel_layer (dict or :obj:`ConfigDict`, optional): Voxelization layer
            config. Defaults to None.
        mean (Sequence[Number], optional): The pixel mean of R, G, B channels.
            Defaults to None.
        std (Sequence[Number], optional): The pixel standard deviation of
            R, G, B channels. Defaults to None.
        pad_size_divisor (int): The size of padded image should be
            divisible by ``pad_size_divisor``. Defaults to 1.
        pad_value (Number): The padded pixel value. Defaults to 0.
        pad_mask (bool): Whether to pad instance masks. Defaults to False.
        mask_pad_value (int): The padded pixel value for instance masks.
            Defaults to 0.
        pad_seg (bool): Whether to pad semantic segmentation maps.
            Defaults to False.
        seg_pad_value (int): The padded pixel value for semantic
            segmentation maps. Defaults to 255.
        bgr_to_rgb (bool): Whether to convert image from BGR to RGB.
            Defaults to False.
        rgb_to_bgr (bool): Whether to convert image from RGB to BGR.
            Defaults to False.
        boxtype2tensor (bool): Whether to convert the ``BaseBoxes`` type of
            bboxes data to Tensor type. Defaults to True.
        batch_augments (List[dict], optional): Batch-level augmentations.
            Defaults to None.
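
    Example:
        A config-style instantiation sketch. The values below are
        illustrative placeholders, not recommended settings; the
        ``voxel_layer`` keys follow ``VoxelizationByGridShape``::

            data_preprocessor = dict(
                type='Det3DDataPreprocessor',
                voxel=True,
                voxel_layer=dict(
                    max_num_points=32,
                    point_cloud_range=[0, -40, -3, 70.4, 40, 1],
                    voxel_size=[0.05, 0.05, 0.1],
                    max_voxels=(16000, 40000)),
                mean=[123.675, 116.28, 103.53],
                std=[58.395, 57.12, 57.375],
                bgr_to_rgb=True)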
    """

    def __init__(self,
                 voxel: bool = False,
                 voxel_type: str = 'hard',
                 voxel_layer: OptConfigType = None,
                 mean: Sequence[Number] = None,
                 std: Sequence[Number] = None,
                 pad_size_divisor: int = 1,
                 pad_value: Union[float, int] = 0,
                 pad_mask: bool = False,
                 mask_pad_value: int = 0,
                 pad_seg: bool = False,
                 seg_pad_value: int = 255,
                 bgr_to_rgb: bool = False,
                 rgb_to_bgr: bool = False,
                 boxtype2tensor: bool = True,
                 batch_augments: Optional[List[dict]] = None) -> None:
        super(Det3DDataPreprocessor, self).__init__(
            mean=mean,
            std=std,
            pad_size_divisor=pad_size_divisor,
            pad_value=pad_value,
            pad_mask=pad_mask,
            mask_pad_value=mask_pad_value,
            pad_seg=pad_seg,
            seg_pad_value=seg_pad_value,
            bgr_to_rgb=bgr_to_rgb,
            rgb_to_bgr=rgb_to_bgr,
            batch_augments=batch_augments)
        self.voxel = voxel
        self.voxel_type = voxel_type
        if voxel:
            self.voxel_layer = VoxelizationByGridShape(**voxel_layer)

    def forward(self,
                data: Union[dict, List[dict]],
                training: bool = False) -> Union[dict, List[dict]]:
        """Perform normalization, padding and bgr2rgb conversion based on
113
114
115
        ``BaseDataPreprocessor``.

        Args:
116
            data (dict or List[dict]): Data from dataloader.
117
118
                The dict contains the whole batch data, when it is
                a list[dict], the list indicate test time augmentation.
119
            training (bool): Whether to enable training time augmentation.
120
                Defaults to False.
121
122

        Returns:
123
            dict or List[dict]: Data in the same format as the model input.
124
        """
        if isinstance(data, list):
            num_augs = len(data)
            aug_batch_data = []
            for aug_id in range(num_augs):
                single_aug_batch_data = self.simple_process(
                    data[aug_id], training)
                aug_batch_data.append(single_aug_batch_data)
            return aug_batch_data

        else:
            return self.simple_process(data, training)

    def simple_process(self, data: dict, training: bool = False) -> dict:
        """Perform normalization, padding and bgr2rgb conversion for img data
        based on ``BaseDataPreprocessor``, and voxelize point cloud if
        ``voxel`` is set to True.

        Args:
            data (dict): Data sampled from dataloader.
            training (bool): Whether to enable training-time augmentation.
                Defaults to False.

        Returns:
            dict: Data in the same format as the model input.
        """
        if 'img' in data['inputs']:
            batch_pad_shape = self._get_pad_shape(data)

        data = self.collate_data(data)
        inputs, data_samples = data['inputs'], data['data_samples']
        batch_inputs = dict()

        if 'points' in inputs:
            batch_inputs['points'] = inputs['points']

            if self.voxel:
                voxel_dict = self.voxelize(inputs['points'], data_samples)
                batch_inputs['voxels'] = voxel_dict

        if 'imgs' in inputs:
            imgs = inputs['imgs']

            if data_samples is not None:
                # NOTE the batched image size information may be useful, e.g.
                # in DETR, this is needed for the construction of masks, which
                # is then used for the transformer_head.
                batch_input_shape = tuple(imgs[0].size()[-2:])
                for data_sample, pad_shape in zip(data_samples,
                                                  batch_pad_shape):
                    data_sample.set_metainfo({
                        'batch_input_shape': batch_input_shape,
                        'pad_shape': pad_shape
                    })

                if hasattr(self, 'boxtype2tensor') and self.boxtype2tensor:
                    from mmdet.models.utils.misc import \
                        samplelist_boxtype2tensor
                    samplelist_boxtype2tensor(data_samples)
                elif hasattr(self, 'boxlist2tensor') and self.boxlist2tensor:
                    from mmdet.models.utils.misc import \
                        samplelist_boxlist2tensor
                    samplelist_boxlist2tensor(data_samples)

                if self.pad_mask:
                    self.pad_gt_masks(data_samples)

                if self.pad_seg:
                    self.pad_gt_sem_seg(data_samples)

            if training and self.batch_augments is not None:
                for batch_aug in self.batch_augments:
                    imgs, data_samples = batch_aug(imgs, data_samples)
            batch_inputs['imgs'] = imgs

        return {'inputs': batch_inputs, 'data_samples': data_samples}

    def preprocess_img(self, _batch_img: torch.Tensor) -> torch.Tensor:
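        """Perform channel conversion and normalization on a single image.

        Args:
            _batch_img (Tensor): A single image tensor of shape (C, H, W).

        Returns:
            Tensor: The converted and normalized image tensor.
        """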
        # channel transform
        if self._channel_conversion:
            _batch_img = _batch_img[[2, 1, 0], ...]
        # Convert to float after channel conversion to ensure
        # efficiency
        _batch_img = _batch_img.float()
        # Normalization.
        if self._enable_normalize:
            if self.mean.shape[0] == 3:
                assert _batch_img.dim() == 3 and _batch_img.shape[0] == 3, (
                    'If the mean has 3 values, the input tensor '
                    'should be in shape of (3, H, W), but got the '
                    f'tensor with shape {_batch_img.shape}')
            _batch_img = (_batch_img - self.mean) / self.std
        return _batch_img

    def collate_data(self, data: dict) -> dict:
        """Copy data to the target device and perform normalization, padding,
        bgr2rgb conversion and stacking based on ``BaseDataPreprocessor``.

        Collates the data sampled from dataloader into a list of dict and
        a list of labels, and then copies tensors to the target device.

        Args:
            data (dict): Data sampled from dataloader.

        Returns:
            dict: Data in the same format as the model input.
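
        Example:
            A minimal sketch of the single-view image path, assuming
            ``pad_size_divisor=32`` and two CHW image tensors of different
            sizes (shapes are illustrative)::

                data = dict(inputs=dict(img=[torch.rand(3, 224, 200),
                                             torch.rand(3, 256, 224)]))
                data = self.collate_data(data)
                # images are padded to the batch max, rounded up to a
                # multiple of 32, and stacked under the new 'imgs' key
                data['inputs']['imgs'].shape  # (2, 3, 256, 224)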
        """
        data = self.cast_data(data)  # type: ignore

        if 'img' in data['inputs']:
            _batch_imgs = data['inputs']['img']
            # Process data with `pseudo_collate`.
            if is_list_of(_batch_imgs, torch.Tensor):
                batch_imgs = []
                img_dim = _batch_imgs[0].dim()
                for _batch_img in _batch_imgs:
                    if img_dim == 3:  # standard img
                        _batch_img = self.preprocess_img(_batch_img)
                    elif img_dim == 4:
                        _batch_img = [
                            self.preprocess_img(_img) for _img in _batch_img
                        ]

                        _batch_img = torch.stack(_batch_img, dim=0)

                    batch_imgs.append(_batch_img)

                # Pad and stack Tensor.
                if img_dim == 3:
                    batch_imgs = stack_batch(batch_imgs, self.pad_size_divisor,
                                             self.pad_value)
                elif img_dim == 4:
                    batch_imgs = multiview_img_stack_batch(
                        batch_imgs, self.pad_size_divisor, self.pad_value)

            # Process data with `default_collate`.
            elif isinstance(_batch_imgs, torch.Tensor):
                assert _batch_imgs.dim() == 4, (
                    'The input of `ImgDataPreprocessor` should be a NCHW '
                    'tensor or a list of tensor, but got a tensor with '
                    f'shape: {_batch_imgs.shape}')
                if self._channel_conversion:
                    _batch_imgs = _batch_imgs[:, [2, 1, 0], ...]
                # Convert to float after channel conversion to ensure
                # efficiency
                _batch_imgs = _batch_imgs.float()
                if self._enable_normalize:
                    _batch_imgs = (_batch_imgs - self.mean) / self.std
                h, w = _batch_imgs.shape[2:]
                target_h = math.ceil(
                    h / self.pad_size_divisor) * self.pad_size_divisor
                target_w = math.ceil(
                    w / self.pad_size_divisor) * self.pad_size_divisor
                pad_h = target_h - h
                pad_w = target_w - w
                batch_imgs = F.pad(_batch_imgs, (0, pad_w, 0, pad_h),
                                   'constant', self.pad_value)
            else:
                raise TypeError(
                    'Output of `cast_data` should be a list of dict '
                    'or a tuple with inputs and data_samples, but got '
                    f'{type(data)}: {data}')

            data['inputs']['imgs'] = batch_imgs

        data.setdefault('data_samples', None)

        return data

    def _get_pad_shape(self, data: dict) -> List[tuple]:
        """Get the pad_shape of each image based on data and
        pad_size_divisor."""
        # rewrite `_get_pad_shape` for obtaining image inputs.
        _batch_inputs = data['inputs']['img']
        # Process data with `pseudo_collate`.
        if is_list_of(_batch_inputs, torch.Tensor):
            batch_pad_shape = []
            for ori_input in _batch_inputs:
                if ori_input.dim() == 4:
                    # means multiview input; select one of the images to
                    # calculate the pad shape
                    ori_input = ori_input[0]
                pad_h = int(
                    np.ceil(ori_input.shape[1] /
                            self.pad_size_divisor)) * self.pad_size_divisor
                pad_w = int(
                    np.ceil(ori_input.shape[2] /
                            self.pad_size_divisor)) * self.pad_size_divisor
                batch_pad_shape.append((pad_h, pad_w))
        # Process data with `default_collate`.
        elif isinstance(_batch_inputs, torch.Tensor):
            assert _batch_inputs.dim() == 4, (
                'The input of `ImgDataPreprocessor` should be a NCHW tensor '
                'or a list of tensor, but got a tensor with shape: '
                f'{_batch_inputs.shape}')
            # NCHW layout: height and width are dims 2 and 3, consistent
            # with the padding logic in `collate_data`
            pad_h = int(
                np.ceil(_batch_inputs.shape[2] /
                        self.pad_size_divisor)) * self.pad_size_divisor
            pad_w = int(
                np.ceil(_batch_inputs.shape[3] /
                        self.pad_size_divisor)) * self.pad_size_divisor
            batch_pad_shape = [(pad_h, pad_w)] * _batch_inputs.shape[0]
        else:
            raise TypeError('Output of `cast_data` should be a list of dict '
                            'or a tuple with inputs and data_samples, but got '
                            f'{type(data)}: {data}')
        return batch_pad_shape

    @torch.no_grad()
    def voxelize(self, points: List[torch.Tensor],
                 data_samples: SampleList) -> Dict[str, torch.Tensor]:
        """Apply voxelization to point cloud.

        Args:
            points (List[Tensor]): Point cloud in one data batch.
            data_samples (list[:obj:`Det3DDataSample`]): The annotation data
                of each sample. Add voxel-wise annotation for segmentation.

        Returns:
            Dict[str, Tensor]: Voxelization information.

            - voxels (Tensor): Features of voxels, shape is MxNxC for hard
              voxelization, NxC for dynamic voxelization.
            - coors (Tensor): Coordinates of voxels, shape is Nx(1+NDim),
              where 1 represents the batch index.
            - num_points (Tensor, optional): Number of points in each voxel.
            - voxel_centers (Tensor, optional): Centers of voxels.
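
        Example:
            A rough sketch of the returned dict for hard voxelization,
            where M is the number of voxels and N is the maximum number of
            points per voxel (shapes are illustrative)::

                voxel_dict = self.voxelize(points, data_samples)
                voxel_dict['voxels'].shape         # (M, N, C)
                voxel_dict['coors'].shape          # (M, 4), batch idx first
                voxel_dict['num_points'].shape     # (M,)
                voxel_dict['voxel_centers'].shape  # (M, 3)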
        """

        voxel_dict = dict()

        if self.voxel_type == 'hard':
            voxels, coors, num_points, voxel_centers = [], [], [], []
            for i, res in enumerate(points):
                res_voxels, res_coors, res_num_points = self.voxel_layer(res)
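                # `res_coors` from the voxel layer is assumed to be in
                # (z, y, x) order here; reorder to (x, y, z) before scaling
                # by the voxel size and adding the range minimum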
                res_voxel_centers = (
                    res_coors[:, [2, 1, 0]] + 0.5) * res_voxels.new_tensor(
                        self.voxel_layer.voxel_size) + res_voxels.new_tensor(
                            self.voxel_layer.point_cloud_range[0:3])
                res_coors = F.pad(res_coors, (1, 0), mode='constant', value=i)
                voxels.append(res_voxels)
                coors.append(res_coors)
                num_points.append(res_num_points)
                voxel_centers.append(res_voxel_centers)

            voxels = torch.cat(voxels, dim=0)
            coors = torch.cat(coors, dim=0)
            num_points = torch.cat(num_points, dim=0)
            voxel_centers = torch.cat(voxel_centers, dim=0)

            voxel_dict['num_points'] = num_points
            voxel_dict['voxel_centers'] = voxel_centers
        elif self.voxel_type == 'dynamic':
            coors = []
            # dynamic voxelization only provides a coors mapping
            for i, res in enumerate(points):
                res_coors = self.voxel_layer(res)
                res_coors = F.pad(res_coors, (1, 0), mode='constant', value=i)
                coors.append(res_coors)
            voxels = torch.cat(points, dim=0)
            coors = torch.cat(coors, dim=0)
        elif self.voxel_type == 'cylindrical':
            voxels, coors = [], []
            for i, (res, data_sample) in enumerate(zip(points, data_samples)):
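                # convert Cartesian (x, y) to polar coordinates: radius rho
                # and azimuth phi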
                rho = torch.sqrt(res[:, 0]**2 + res[:, 1]**2)
                phi = torch.atan2(res[:, 1], res[:, 0])
                polar_res = torch.stack((rho, phi, res[:, 2]), dim=-1)
                # Currently we only support PyTorch >= 1.9.0, and will
                # implement it in voxel_layer soon for better compatibility
                min_bound = polar_res.new_tensor(
                    self.voxel_layer.point_cloud_range[:3])
                max_bound = polar_res.new_tensor(
                    self.voxel_layer.point_cloud_range[3:])
                polar_res = torch.clamp(polar_res, min_bound, max_bound)
                res_coors = torch.floor(
                    (polar_res - min_bound) /
                    polar_res.new_tensor(self.voxel_layer.voxel_size)).int()
                if self.training:
                    self.get_voxel_seg(res_coors, data_sample)
                res_coors = F.pad(res_coors, (1, 0), mode='constant', value=i)
                res_voxels = torch.cat((polar_res, res[:, :2], res[:, 3:]),
                                       dim=-1)
                voxels.append(res_voxels)
                coors.append(res_coors)
            voxels = torch.cat(voxels, dim=0)
            coors = torch.cat(coors, dim=0)
        else:
            raise ValueError(f'Invalid voxelization type {self.voxel_type}')

        voxel_dict['voxels'] = voxels
        voxel_dict['coors'] = coors

        return voxel_dict

    def get_voxel_seg(self, res_coors: torch.Tensor, data_sample: SampleList):
        """Get voxel-wise segmentation label and point2voxel map.

        Args:
            res_coors (Tensor): The voxel coordinates of points, Nx3.
            data_sample (:obj:`Det3DDataSample`): The annotation data of
                each sample. Add voxel-wise annotation for segmentation.
        """
        pts_semantic_mask = data_sample.gt_pts_seg.pts_semantic_mask
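        # scatter one-hot point labels into voxels with 'mean' reduction,
        # then argmax recovers the majority label of each voxel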
        voxel_semantic_mask, _, point2voxel_map = dynamic_scatter_3d(
            F.one_hot(pts_semantic_mask.long()).float(), res_coors, 'mean',
            True)
        voxel_semantic_mask = torch.argmax(voxel_semantic_mask, dim=-1)
        data_sample.gt_pts_seg.voxel_semantic_mask = voxel_semantic_mask
        data_sample.gt_pts_seg.point2voxel_map = point2voxel_map