# Copyright (c) OpenMMLab. All rights reserved.
import math
from numbers import Number
from typing import Dict, List, Optional, Sequence, Union

import numpy as np
import torch
from mmdet.models import DetDataPreprocessor
from mmengine.model import stack_batch
from mmengine.utils import is_list_of
from torch.nn import functional as F

from mmdet3d.registry import MODELS
from mmdet3d.structures.det3d_data_sample import SampleList
from mmdet3d.utils import OptConfigType
from .utils import multiview_img_stack_batch
from .voxelize import VoxelizationByGridShape, dynamic_scatter_3d


@MODELS.register_module()
class Det3DDataPreprocessor(DetDataPreprocessor):
    """Points / Image pre-processor for point clouds / vision-only / multi-
    modality 3D detection tasks.

    It provides the data pre-processing as follows:

    - Collate and move image and point cloud data to the target device.

    - 1) For image data:

      - Pad images in inputs to the maximum size of the current batch with
        the defined ``pad_value``. The padded size is divisible by the
        defined ``pad_size_divisor``.
      - Stack images in inputs to batch_imgs.
      - Convert images in inputs from bgr to rgb if the shape of input is
        (3, H, W).
      - Normalize images in inputs with defined std and mean.
      - Do batch augmentations during training.

    - 2) For point cloud data:

      - If no voxelization, directly return list of point cloud data.
      - If voxelization is applied, voxelize point cloud according to
        ``voxel_type`` and obtain ``voxels``.

    Args:
        voxel (bool): Whether to apply voxelization to point cloud.
            Defaults to False.
        voxel_type (str): Voxelization type. Two voxelization types are
            provided: 'hard' and 'dynamic', respectively for hard
            voxelization and dynamic voxelization. Defaults to 'hard'.
        voxel_layer (dict or :obj:`ConfigDict`, optional): Voxelization layer
            config. Defaults to None.
        mean (Sequence[Number], optional): The pixel mean of R, G, B channels.
            Defaults to None.
        std (Sequence[Number], optional): The pixel standard deviation of
            R, G, B channels. Defaults to None.
        pad_size_divisor (int): The size of padded image should be
            divisible by ``pad_size_divisor``. Defaults to 1.
        pad_value (Number): The padded pixel value. Defaults to 0.
        pad_mask (bool): Whether to pad instance masks. Defaults to False.
        mask_pad_value (int): The padded pixel value for instance masks.
            Defaults to 0.
        pad_seg (bool): Whether to pad semantic segmentation maps.
            Defaults to False.
        seg_pad_value (int): The padded pixel value for semantic
            segmentation maps. Defaults to 255.
        bgr_to_rgb (bool): Whether to convert image from BGR to RGB.
            Defaults to False.
        rgb_to_bgr (bool): Whether to convert image from RGB to BGR.
            Defaults to False.
        boxtype2tensor (bool): Whether to convert the ``BaseBoxes`` type of
            bboxes data to ``Tensor`` type. Defaults to True.
        batch_augments (List[dict], optional): Batch-level augmentations.
            Defaults to None.
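
    Example:
        An illustrative construction (the ``voxel_layer`` values below are
        placeholder settings in the style of KITTI configs, not built-in
        defaults)::

            >>> preprocessor = Det3DDataPreprocessor(
            ...     voxel=True,
            ...     voxel_type='hard',
            ...     voxel_layer=dict(
            ...         max_num_points=32,
            ...         point_cloud_range=[0, -40, -3, 70.4, 40, 1],
            ...         voxel_size=[0.05, 0.05, 0.1],
            ...         max_voxels=(16000, 40000)),
            ...     mean=[123.675, 116.28, 103.53],
            ...     std=[58.395, 57.12, 57.375],
            ...     bgr_to_rgb=True)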
    """

    def __init__(self,
                 voxel: bool = False,
                 voxel_type: str = 'hard',
                 voxel_layer: OptConfigType = None,
                 mean: Sequence[Number] = None,
                 std: Sequence[Number] = None,
                 pad_size_divisor: int = 1,
                 pad_value: Union[float, int] = 0,
                 pad_mask: bool = False,
                 mask_pad_value: int = 0,
                 pad_seg: bool = False,
                 seg_pad_value: int = 255,
                 bgr_to_rgb: bool = False,
                 rgb_to_bgr: bool = False,
                 boxtype2tensor: bool = True,
                 batch_augments: Optional[List[dict]] = None) -> None:
        super(Det3DDataPreprocessor, self).__init__(
            mean=mean,
            std=std,
            pad_size_divisor=pad_size_divisor,
            pad_value=pad_value,
            pad_mask=pad_mask,
            mask_pad_value=mask_pad_value,
            pad_seg=pad_seg,
            seg_pad_value=seg_pad_value,
            bgr_to_rgb=bgr_to_rgb,
            rgb_to_bgr=rgb_to_bgr,
            batch_augments=batch_augments)
        self.voxel = voxel
        self.voxel_type = voxel_type
        if voxel:
            self.voxel_layer = VoxelizationByGridShape(**voxel_layer)

    def forward(self,
                data: Union[dict, List[dict]],
                training: bool = False) -> Union[dict, List[dict]]:
        """Perform normalization, padding and bgr2rgb conversion based on
        ``BaseDataPreprocessor``.

        Args:
            data (dict or List[dict]): Data from dataloader.
                The dict contains the whole batch data; when it is a
                list[dict], the list indicates test-time augmentation.
            training (bool): Whether to enable training time augmentation.
                Defaults to False.

        Returns:
            dict or List[dict]: Data in the same format as the model input.
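
        Example:
            A sketch of the expected batch format (the tensor shape and the
            empty data sample are illustrative only)::

                >>> data = dict(
                ...     inputs=dict(points=[torch.rand(100, 4)]),
                ...     data_samples=[Det3DDataSample()])
                >>> out = preprocessor(data, training=False)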
        """
        if isinstance(data, list):
            num_augs = len(data)
            aug_batch_data = []
            for aug_id in range(num_augs):
                single_aug_batch_data = self.simple_process(
                    data[aug_id], training)
                aug_batch_data.append(single_aug_batch_data)
            return aug_batch_data

        else:
            return self.simple_process(data, training)

    def simple_process(self, data: dict, training: bool = False) -> dict:
        """Perform normalization, padding and bgr2rgb conversion for img data
        based on ``BaseDataPreprocessor``, and voxelize point cloud if
        ``voxel`` is set to be True.

        Args:
            data (dict): Data sampled from dataloader.
            training (bool): Whether to enable training time augmentation.
                Defaults to False.

        Returns:
            dict: Data in the same format as the model input.
        """
        if 'img' in data['inputs']:
            batch_pad_shape = self._get_pad_shape(data)

        data = self.collate_data(data)
        inputs, data_samples = data['inputs'], data['data_samples']
        batch_inputs = dict()

        if 'points' in inputs:
            batch_inputs['points'] = inputs['points']

            if self.voxel:
                voxel_dict = self.voxelize(inputs['points'], data_samples)
                batch_inputs['voxels'] = voxel_dict

        if 'imgs' in inputs:
            imgs = inputs['imgs']

            if data_samples is not None:
                # NOTE the batched image size information may be useful, e.g.
                # in DETR, this is needed for the construction of masks, which
                # is then used for the transformer_head.
                batch_input_shape = tuple(imgs[0].size()[-2:])
                for data_sample, pad_shape in zip(data_samples,
                                                  batch_pad_shape):
                    data_sample.set_metainfo({
                        'batch_input_shape': batch_input_shape,
                        'pad_shape': pad_shape
                    })

                # Different mmdet versions expose either `boxtype2tensor`
                # or `boxlist2tensor`; check both for compatibility.
                if hasattr(self, 'boxtype2tensor') and self.boxtype2tensor:
                    from mmdet.models.utils.misc import \
                        samplelist_boxtype2tensor
                    samplelist_boxtype2tensor(data_samples)
                elif hasattr(self, 'boxlist2tensor') and self.boxlist2tensor:
                    from mmdet.models.utils.misc import \
                        samplelist_boxlist2tensor
                    samplelist_boxlist2tensor(data_samples)
                if self.pad_mask:
                    self.pad_gt_masks(data_samples)

                if self.pad_seg:
                    self.pad_gt_sem_seg(data_samples)

            if training and self.batch_augments is not None:
                for batch_aug in self.batch_augments:
                    imgs, data_samples = batch_aug(imgs, data_samples)
            batch_inputs['imgs'] = imgs

        return {'inputs': batch_inputs, 'data_samples': data_samples}

    def preprocess_img(self, _batch_img: torch.Tensor) -> torch.Tensor:
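        """Normalize a single image tensor.

        Applies optional bgr-to-rgb channel conversion, casts to float, and
        then normalizes with the configured mean and std.
        """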
        # channel transform
        if self._channel_conversion:
            _batch_img = _batch_img[[2, 1, 0], ...]
        # Convert to float after channel conversion to ensure
        # efficiency
        _batch_img = _batch_img.float()
        # Normalization.
        if self._enable_normalize:
            if self.mean.shape[0] == 3:
                assert _batch_img.dim() == 3 and _batch_img.shape[0] == 3, (
                    'If the mean has 3 values, the input tensor '
                    'should in shape of (3, H, W), but got the '
                    f'tensor with shape {_batch_img.shape}')
            _batch_img = (_batch_img - self.mean) / self.std
        return _batch_img

    def collate_data(self, data: dict) -> dict:
        """Copying data to the target device and Performs normalization,
219
220
        padding and bgr2rgb conversion and stack based on
        ``BaseDataPreprocessor``.
221
222
223
224
225

        Collates the data sampled from dataloader into a list of dicts and a
        list of labels, and then copies tensors to the target device.

        Args:
            data (dict): Data sampled from dataloader.

        Returns:
            dict: Data in the same format as the model input.
        """
        data = self.cast_data(data)  # type: ignore

        if 'img' in data['inputs']:
            _batch_imgs = data['inputs']['img']
            # Process data with `pseudo_collate`.
            if is_list_of(_batch_imgs, torch.Tensor):
                batch_imgs = []
                img_dim = _batch_imgs[0].dim()
                for _batch_img in _batch_imgs:
                    if img_dim == 3:  # standard img
                        _batch_img = self.preprocess_img(_batch_img)
                    elif img_dim == 4:
                        _batch_img = [
                            self.preprocess_img(_img) for _img in _batch_img
                        ]

                        _batch_img = torch.stack(_batch_img, dim=0)

                    batch_imgs.append(_batch_img)

                # Pad and stack Tensor.
                if img_dim == 3:
                    batch_imgs = stack_batch(batch_imgs, self.pad_size_divisor,
                                             self.pad_value)
                elif img_dim == 4:
                    batch_imgs = multiview_img_stack_batch(
                        batch_imgs, self.pad_size_divisor, self.pad_value)

            # Process data with `default_collate`.
            elif isinstance(_batch_imgs, torch.Tensor):
                assert _batch_imgs.dim() == 4, (
                    'The input of `ImgDataPreprocessor` should be a NCHW '
                    'tensor or a list of tensor, but got a tensor with '
                    f'shape: {_batch_imgs.shape}')
                if self._channel_conversion:
                    _batch_imgs = _batch_imgs[:, [2, 1, 0], ...]
                # Convert to float after channel conversion to ensure
                # efficiency
                _batch_imgs = _batch_imgs.float()
                if self._enable_normalize:
                    _batch_imgs = (_batch_imgs - self.mean) / self.std
                h, w = _batch_imgs.shape[2:]
                target_h = math.ceil(
                    h / self.pad_size_divisor) * self.pad_size_divisor
                target_w = math.ceil(
                    w / self.pad_size_divisor) * self.pad_size_divisor
                pad_h = target_h - h
                pad_w = target_w - w
                batch_imgs = F.pad(_batch_imgs, (0, pad_w, 0, pad_h),
                                   'constant', self.pad_value)
            else:
                raise TypeError(
                    'Output of `cast_data` should be a list of dict '
                    'or a tuple with inputs and data_samples, but got '
                    f'{type(data)}: {data}')

            data['inputs']['imgs'] = batch_imgs

        data.setdefault('data_samples', None)

        return data

    def _get_pad_shape(self, data: dict) -> List[tuple]:
        """Get the pad_shape of each image based on data and
        pad_size_divisor."""
        # Rewrite `_get_pad_shape` to obtain image inputs.
        _batch_inputs = data['inputs']['img']
        # Process data with `pseudo_collate`.
        if is_list_of(_batch_inputs, torch.Tensor):
            batch_pad_shape = []
            for ori_input in _batch_inputs:
                if ori_input.dim() == 4:
                    # means multiview input, select one of the
                    # images to calculate the pad shape
                    ori_input = ori_input[0]
                pad_h = int(
                    np.ceil(ori_input.shape[1] /
                            self.pad_size_divisor)) * self.pad_size_divisor
                pad_w = int(
                    np.ceil(ori_input.shape[2] /
                            self.pad_size_divisor)) * self.pad_size_divisor
                batch_pad_shape.append((pad_h, pad_w))
        # Process data with `default_collate`.
        elif isinstance(_batch_inputs, torch.Tensor):
            assert _batch_inputs.dim() == 4, (
                'The input of `ImgDataPreprocessor` should be a NCHW tensor '
                'or a list of tensor, but got a tensor with shape: '
                f'{_batch_inputs.shape}')
            # For an NCHW tensor, H and W are dims 2 and 3.
            pad_h = int(
                np.ceil(_batch_inputs.shape[2] /
                        self.pad_size_divisor)) * self.pad_size_divisor
            pad_w = int(
                np.ceil(_batch_inputs.shape[3] /
                        self.pad_size_divisor)) * self.pad_size_divisor
            batch_pad_shape = [(pad_h, pad_w)] * _batch_inputs.shape[0]
        else:
            raise TypeError('Output of `cast_data` should be a list of dict '
                            'or a tuple with inputs and data_samples, but got '
                            f'{type(data)}: {data}')
        return batch_pad_shape

    @torch.no_grad()
    def voxelize(self, points: List[torch.Tensor],
                 data_samples: SampleList) -> Dict[str, torch.Tensor]:
        """Apply voxelization to point cloud.

        Args:
            points (List[Tensor]): Point cloud in one data batch.
            data_samples (list[:obj:`Det3DDataSample`]): The annotation data
                of each sample. Voxel-wise annotations are added for
                segmentation.

        Returns:
            Dict[str, Tensor]: Voxelization information.

            - voxels (Tensor): Features of voxels, shape is MxNxC for hard
              voxelization, NxC for dynamic voxelization.
            - coors (Tensor): Coordinates of voxels, shape is Nx(1+NDim),
              where 1 represents the batch index.
            - num_points (Tensor, optional): Number of points in each voxel.
            - voxel_centers (Tensor, optional): Centers of voxels.
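
        Example:
            An illustrative peek at the result for ``voxel_type='hard'``
            (shapes are placeholders)::

                >>> voxel_dict = self.voxelize(points, data_samples)
                >>> voxel_dict['voxels'].shape  # (M, max_points, C)
                >>> voxel_dict['coors'].shape   # (M, 1 + 3)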
        """

        voxel_dict = dict()

        if self.voxel_type == 'hard':
            voxels, coors, num_points, voxel_centers = [], [], [], []
            for i, res in enumerate(points):
                res_voxels, res_coors, res_num_points = self.voxel_layer(res)
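                # Voxel coors are (z, y, x) grid indices: reorder to
                # (x, y, z), shift by 0.5 to the cell center, then scale by
                # the voxel size and add the range minimum to obtain metric
                # centers.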
                res_voxel_centers = (
                    res_coors[:, [2, 1, 0]] + 0.5) * res_voxels.new_tensor(
                        self.voxel_layer.voxel_size) + res_voxels.new_tensor(
                            self.voxel_layer.point_cloud_range[0:3])
                res_coors = F.pad(res_coors, (1, 0), mode='constant', value=i)
                voxels.append(res_voxels)
                coors.append(res_coors)
                num_points.append(res_num_points)
                voxel_centers.append(res_voxel_centers)

            voxels = torch.cat(voxels, dim=0)
            coors = torch.cat(coors, dim=0)
            num_points = torch.cat(num_points, dim=0)
            voxel_centers = torch.cat(voxel_centers, dim=0)

            voxel_dict['num_points'] = num_points
            voxel_dict['voxel_centers'] = voxel_centers
        elif self.voxel_type == 'dynamic':
            coors = []
            # dynamic voxelization only provides a coors mapping
            for i, res in enumerate(points):
                res_coors = self.voxel_layer(res)
                res_coors = F.pad(res_coors, (1, 0), mode='constant', value=i)
                coors.append(res_coors)
            voxels = torch.cat(points, dim=0)
            coors = torch.cat(coors, dim=0)
        elif self.voxel_type == 'cylindrical':
            voxels, coors = [], []
            for i, (res, data_sample) in enumerate(zip(points, data_samples)):
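                # Convert Cartesian (x, y) to polar coordinates: rho is the
                # radial distance, phi the azimuth; z is kept as-is.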
                rho = torch.sqrt(res[:, 0]**2 + res[:, 1]**2)
                phi = torch.atan2(res[:, 1], res[:, 0])
                polar_res = torch.stack((rho, phi, res[:, 2]), dim=-1)
                min_bound = polar_res.new_tensor(
                    self.voxel_layer.point_cloud_range[:3])
                max_bound = polar_res.new_tensor(
                    self.voxel_layer.point_cloud_range[3:])
                try:  # only support PyTorch >= 1.9.0
                    polar_res_clamp = torch.clamp(polar_res, min_bound,
                                                  max_bound)
                except TypeError:
                    polar_res_clamp = polar_res.clone()
                    for coor_idx in range(3):
                        polar_res_clamp[:, coor_idx][
                            polar_res[:, coor_idx] >
                            max_bound[coor_idx]] = max_bound[coor_idx]
                        polar_res_clamp[:, coor_idx][
                            polar_res[:, coor_idx] <
                            min_bound[coor_idx]] = min_bound[coor_idx]
                res_coors = torch.floor(
                    (polar_res_clamp - min_bound) / polar_res_clamp.new_tensor(
                        self.voxel_layer.voxel_size)).int()
                self.get_voxel_seg(res_coors, data_sample)
                res_coors = F.pad(res_coors, (1, 0), mode='constant', value=i)
                res_voxels = torch.cat((polar_res, res[:, :2], res[:, 3:]),
                                       dim=-1)
                voxels.append(res_voxels)
                coors.append(res_coors)
            voxels = torch.cat(voxels, dim=0)
            coors = torch.cat(coors, dim=0)
        elif self.voxel_type == 'minkunet':
            voxels, coors = [], []
            voxel_size = points[0].new_tensor(self.voxel_layer.voxel_size)
            for i, (res, data_sample) in enumerate(zip(points, data_samples)):
                res_coors = torch.round(res[:, :3] / voxel_size).int()
                res_coors -= res_coors.min(0)[0]

                res_coors_numpy = res_coors.cpu().numpy()
                inds, voxel2point_map = self.sparse_quantize(
                    res_coors_numpy, return_index=True, return_inverse=True)
                voxel2point_map = torch.from_numpy(voxel2point_map).cuda()
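                # During training, randomly keep at most 80000 voxels to
                # bound memory usage.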
                if self.training:
                    if len(inds) > 80000:
                        inds = np.random.choice(inds, 80000, replace=False)
                inds = torch.from_numpy(inds).cuda()
                data_sample.gt_pts_seg.voxel_semantic_mask \
                    = data_sample.gt_pts_seg.pts_semantic_mask[inds]
                res_voxel_coors = res_coors[inds]
                res_voxels = res[inds]
                res_voxel_coors = F.pad(
                    res_voxel_coors, (0, 1), mode='constant', value=i)
                data_sample.voxel2point_map = voxel2point_map.long()
                voxels.append(res_voxels)
                coors.append(res_voxel_coors)
            voxels = torch.cat(voxels, dim=0)
            coors = torch.cat(coors, dim=0)

        else:
            raise ValueError(f'Invalid voxelization type {self.voxel_type}')

        voxel_dict['voxels'] = voxels
        voxel_dict['coors'] = coors

        return voxel_dict

    def get_voxel_seg(self, res_coors: torch.Tensor, data_sample: SampleList):
        """Get voxel-wise segmentation label and point2voxel map.

        Args:
            res_coors (Tensor): The voxel coordinates of points, Nx3.
            data_sample (:obj:`Det3DDataSample`): The annotation data of
                each sample. Voxel-wise annotations are added for
                segmentation.
        """

        if self.training:
            pts_semantic_mask = data_sample.gt_pts_seg.pts_semantic_mask
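            # Majority vote per voxel: scatter one-hot labels with 'mean'
            # reduction, then argmax recovers the dominant class.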
            voxel_semantic_mask, _, point2voxel_map = dynamic_scatter_3d(
                F.one_hot(pts_semantic_mask.long()).float(), res_coors, 'mean',
                True)
            voxel_semantic_mask = torch.argmax(voxel_semantic_mask, dim=-1)
            data_sample.gt_pts_seg.voxel_semantic_mask = voxel_semantic_mask
            data_sample.gt_pts_seg.point2voxel_map = point2voxel_map
        else:
            pseudo_tensor = res_coors.new_ones([res_coors.shape[0], 1]).float()
            _, _, point2voxel_map = dynamic_scatter_3d(pseudo_tensor,
                                                       res_coors, 'mean', True)
            data_sample.gt_pts_seg.point2voxel_map = point2voxel_map

    def ravel_hash(self, x: np.ndarray) -> np.ndarray:
        """Get voxel coordinates hash for np.unique().

        Args:
            x (np.ndarray): The voxel coordinates of points, Nx3.

        Returns:
            np.ndarray: Voxel coordinates hash.
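
        Example:
            A minimal sanity check (values are illustrative): duplicate
            coordinates hash to the same value, distinct ones differ::

                >>> coords = np.array([[0, 0, 0], [0, 0, 1], [0, 0, 0]])
                >>> h = self.ravel_hash(coords)
                >>> h[0] == h[2] and h[0] != h[1]
                True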
        """
        assert x.ndim == 2, x.shape

        x = x - np.min(x, axis=0)
        x = x.astype(np.uint64, copy=False)
        xmax = np.max(x, axis=0).astype(np.uint64) + 1

        h = np.zeros(x.shape[0], dtype=np.uint64)
        for k in range(x.shape[1] - 1):
            h += x[:, k]
            h *= xmax[k + 1]
        h += x[:, -1]
        return h

    def sparse_quantize(self,
                        coords: np.ndarray,
                        return_index: bool = False,
                        return_inverse: bool = False) -> List[np.ndarray]:
        """Sparse Quantization for voxel coordinates used in Minkunet.

        Args:
            coords (np.ndarray): The voxel coordinates of points, Nx3.
            return_index (bool): Whether to return the indices of the
                unique coords, shape (M,).
            return_inverse (bool): Whether to return the indices of the
                original coords shape (N,).

        Returns:
            List[np.ndarray]: The indices and/or the inverse map, depending
            on ``return_index`` and ``return_inverse``.
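
        Example:
            An illustrative call (coordinates are placeholders)::

                >>> coords = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]])
                >>> inds, inverse = self.sparse_quantize(
                ...     coords, return_index=True, return_inverse=True)
                >>> inds  # first occurrence of each unique voxel
                array([0, 1])
                >>> inverse  # voxel index for every input point
                array([0, 1, 0])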
        """
        _, indices, inverse_indices = np.unique(
            self.ravel_hash(coords), return_index=True, return_inverse=True)
        coords = coords[indices]

        outputs = []
        if return_index:
            outputs += [indices]
        if return_inverse:
            outputs += [inverse_indices]
        return outputs