# Copyright (c) OpenMMLab. All rights reserved.
import math
from numbers import Number
from typing import Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
from mmdet.models import DetDataPreprocessor
from mmdet.models.utils.misc import samplelist_boxtype2tensor
from mmengine.model import stack_batch
from mmengine.utils import is_seq_of
from torch import Tensor
from torch.nn import functional as F

from mmdet3d.registry import MODELS
from mmdet3d.structures.det3d_data_sample import SampleList
from mmdet3d.utils import OptConfigType
from .utils import multiview_img_stack_batch
from .voxelize import VoxelizationByGridShape, dynamic_scatter_3d


@MODELS.register_module()
class Det3DDataPreprocessor(DetDataPreprocessor):
    """Points / Image pre-processor for point clouds / vision-only / multi-
    modality 3D detection tasks.

    It provides the data pre-processing as follows:

    - Collate and move image and point cloud data to the target device.

    - 1) For image data:

      - Pad images in inputs to the maximum size of current batch with defined
        ``pad_value``. The padding size can be divisible by a defined
        ``pad_size_divisor``.
      - Stack images in inputs to batch_imgs.
      - Convert images in inputs from bgr to rgb if the shape of input is
        (3, H, W).
      - Normalize images in inputs with defined std and mean.
      - Do batch augmentations during training.

    - 2) For point cloud data:

      - If voxelization is not applied, directly return the list of point
        cloud data.
      - If voxelization is applied, voxelize point cloud according to
        ``voxel_type`` and obtain ``voxels``.

    Args:
        voxel (bool): Whether to apply voxelization to point cloud.
            Defaults to False.
        voxel_type (str): Voxelization type. Two voxelization types are
            provided: 'hard' and 'dynamic', respectively for hard voxelization
            and dynamic voxelization. Defaults to 'hard'.
        voxel_layer (dict or :obj:`ConfigDict`, optional): Voxelization layer
            config. Defaults to None.
        batch_first (bool): Whether to put the batch dimension to the first
            dimension when getting voxel coordinates. Defaults to True.
        max_voxels (int, optional): Maximum number of voxels in each voxel
            grid. Defaults to None.
        mean (Sequence[Number], optional): The pixel mean of R, G, B channels.
            Defaults to None.
        std (Sequence[Number], optional): The pixel standard deviation of
            R, G, B channels. Defaults to None.
        pad_size_divisor (int): The size of padded image should be divisible by
            ``pad_size_divisor``. Defaults to 1.
        pad_value (float or int): The padded pixel value. Defaults to 0.
        pad_mask (bool): Whether to pad instance masks. Defaults to False.
        mask_pad_value (int): The padded pixel value for instance masks.
            Defaults to 0.
        pad_seg (bool): Whether to pad semantic segmentation maps.
            Defaults to False.
        seg_pad_value (int): The padded pixel value for semantic segmentation
            maps. Defaults to 255.
        bgr_to_rgb (bool): Whether to convert image from BGR to RGB.
            Defaults to False.
        rgb_to_bgr (bool): Whether to convert image from RGB to BGR.
            Defaults to False.
        boxtype2tensor (bool): Whether to convert the ``BaseBoxes`` type of
            bboxes data to ``Tensor`` type. Defaults to True.
        non_blocking (bool): Whether to use non-blocking (asynchronous)
            transfer when moving data to the device. Defaults to False.
        batch_augments (List[dict], optional): Batch-level augmentations.
            Defaults to None.
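
    Example:
        A minimal construction sketch; the voxel settings below are
        illustrative placeholders rather than a recommended config::

            >>> data_preprocessor = Det3DDataPreprocessor(
            ...     voxel=True,
            ...     voxel_type='hard',
            ...     voxel_layer=dict(
            ...         max_num_points=32,
            ...         point_cloud_range=[0, -40, -3, 70.4, 40, 1],
            ...         voxel_size=[0.05, 0.05, 0.1],
            ...         max_voxels=(16000, 40000)),
            ...     mean=[123.675, 116.28, 103.53],
            ...     std=[58.395, 57.12, 57.375],
            ...     bgr_to_rgb=True)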
    """

    def __init__(self,
                 voxel: bool = False,
                 voxel_type: str = 'hard',
                 voxel_layer: OptConfigType = None,
                 batch_first: bool = True,
                 max_voxels: Optional[int] = None,
                 mean: Sequence[Number] = None,
                 std: Sequence[Number] = None,
                 pad_size_divisor: int = 1,
                 pad_value: Union[float, int] = 0,
                 pad_mask: bool = False,
                 mask_pad_value: int = 0,
                 pad_seg: bool = False,
                 seg_pad_value: int = 255,
                 bgr_to_rgb: bool = False,
                 rgb_to_bgr: bool = False,
                 boxtype2tensor: bool = True,
                 non_blocking: bool = False,
                 batch_augments: Optional[List[dict]] = None) -> None:
        super(Det3DDataPreprocessor, self).__init__(
            mean=mean,
            std=std,
            pad_size_divisor=pad_size_divisor,
            pad_value=pad_value,
            pad_mask=pad_mask,
            mask_pad_value=mask_pad_value,
            pad_seg=pad_seg,
            seg_pad_value=seg_pad_value,
            bgr_to_rgb=bgr_to_rgb,
            rgb_to_bgr=rgb_to_bgr,
            boxtype2tensor=boxtype2tensor,
            non_blocking=non_blocking,
            batch_augments=batch_augments)
        self.voxel = voxel
        self.voxel_type = voxel_type
        self.batch_first = batch_first
        self.max_voxels = max_voxels
        if voxel:
            self.voxel_layer = VoxelizationByGridShape(**voxel_layer)

    def forward(self,
                data: Union[dict, List[dict]],
                training: bool = False) -> Union[dict, List[dict]]:
        """Perform normalization, padding and bgr2rgb conversion based on
        ``BaseDataPreprocessor``.

        Args:
            data (dict or List[dict]): Data from dataloader. The dict contains
                the whole batch data; when it is a list[dict], the list
                indicates test-time augmentation.
            training (bool): Whether to enable training time augmentation.
                Defaults to False.

        Returns:
            dict or List[dict]: Data in the same format as the model input.
        """
        if isinstance(data, list):
            num_augs = len(data)
            aug_batch_data = []
            for aug_id in range(num_augs):
                single_aug_batch_data = self.simple_process(
                    data[aug_id], training)
                aug_batch_data.append(single_aug_batch_data)
            return aug_batch_data

        else:
            return self.simple_process(data, training)

    def simple_process(self, data: dict, training: bool = False) -> dict:
        """Perform normalization, padding and bgr2rgb conversion for img data
156
157
        based on ``BaseDataPreprocessor``, and voxelize point cloud if `voxel`
        is set to be True.

        Args:
            data (dict): Data sampled from dataloader.
            training (bool): Whether to enable training time augmentation.
                Defaults to False.

        Returns:
            dict: Data in the same format as the model input.
        """
        if 'img' in data['inputs']:
            batch_pad_shape = self._get_pad_shape(data)

        data = self.collate_data(data)
        inputs, data_samples = data['inputs'], data['data_samples']
        batch_inputs = dict()

        if 'points' in inputs:
            batch_inputs['points'] = inputs['points']

            if self.voxel:
                voxel_dict = self.voxelize(inputs['points'], data_samples)
                batch_inputs['voxels'] = voxel_dict

        if 'imgs' in inputs:
            imgs = inputs['imgs']

            if data_samples is not None:
                # NOTE the batched image size information may be useful, e.g.
                # in DETR, this is needed for the construction of masks, which
                # is then used for the transformer_head.
                batch_input_shape = tuple(imgs[0].size()[-2:])
                for data_sample, pad_shape in zip(data_samples,
                                                  batch_pad_shape):
                    data_sample.set_metainfo({
                        'batch_input_shape': batch_input_shape,
                        'pad_shape': pad_shape
                    })

                if self.boxtype2tensor:
                    samplelist_boxtype2tensor(data_samples)
                if self.pad_mask:
                    self.pad_gt_masks(data_samples)
                if self.pad_seg:
                    self.pad_gt_sem_seg(data_samples)

            if training and self.batch_augments is not None:
                for batch_aug in self.batch_augments:
                    imgs, data_samples = batch_aug(imgs, data_samples)
            batch_inputs['imgs'] = imgs

        return {'inputs': batch_inputs, 'data_samples': data_samples}

    def preprocess_img(self, _batch_img: Tensor) -> Tensor:
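        """Process a single image: optional BGR to RGB channel swap, float
        conversion and mean/std normalization.

        Args:
            _batch_img (Tensor): A single image of shape (3, H, W).

        Returns:
            Tensor: The converted and normalized image.
        """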
        # channel transform
        if self._channel_conversion:
            _batch_img = _batch_img[[2, 1, 0], ...]
        # Convert to float after channel conversion to ensure
        # efficiency
        _batch_img = _batch_img.float()
        # Normalization.
        if self._enable_normalize:
            if self.mean.shape[0] == 3:
                assert _batch_img.dim() == 3 and _batch_img.shape[0] == 3, (
                    'If the mean has 3 values, the input tensor '
                    'should be in shape of (3, H, W), but got the '
                    f'tensor with shape {_batch_img.shape}')
            _batch_img = (_batch_img - self.mean) / self.std
        return _batch_img

    def collate_data(self, data: dict) -> dict:
        """Copy data to the target device and perform normalization, padding
        and bgr2rgb conversion and stack based on ``BaseDataPreprocessor``.

        Collates the data sampled from dataloader into a list of dict and list
        of labels, and then copies tensors to the target device.

        Args:
            data (dict): Data sampled from dataloader.

        Returns:
            dict: Data in the same format as the model input.
        """
        data = self.cast_data(data)  # type: ignore

        if 'img' in data['inputs']:
            _batch_imgs = data['inputs']['img']
            # Process data with `pseudo_collate`.
            if is_seq_of(_batch_imgs, torch.Tensor):
                batch_imgs = []
                img_dim = _batch_imgs[0].dim()
                for _batch_img in _batch_imgs:
                    if img_dim == 3:  # standard img
                        _batch_img = self.preprocess_img(_batch_img)
                    elif img_dim == 4:
                        _batch_img = [
                            self.preprocess_img(_img) for _img in _batch_img
                        ]
                        _batch_img = torch.stack(_batch_img, dim=0)

                    batch_imgs.append(_batch_img)

                # Pad and stack Tensor.
                if img_dim == 3:
                    batch_imgs = stack_batch(batch_imgs, self.pad_size_divisor,
                                             self.pad_value)
                elif img_dim == 4:
                    batch_imgs = multiview_img_stack_batch(
                        batch_imgs, self.pad_size_divisor, self.pad_value)

            # Process data with `default_collate`.
            elif isinstance(_batch_imgs, torch.Tensor):
                assert _batch_imgs.dim() == 4, (
                    'The input of `ImgDataPreprocessor` should be a NCHW '
                    'tensor or a list of tensor, but got a tensor with '
                    f'shape: {_batch_imgs.shape}')
                if self._channel_conversion:
                    _batch_imgs = _batch_imgs[:, [2, 1, 0], ...]
                # Convert to float after channel conversion to ensure
                # efficiency
                _batch_imgs = _batch_imgs.float()
                if self._enable_normalize:
                    _batch_imgs = (_batch_imgs - self.mean) / self.std
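                # Round H and W up to the nearest multiple of
                # `pad_size_divisor`, then pad on the bottom and right.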
                h, w = _batch_imgs.shape[2:]
                target_h = math.ceil(
                    h / self.pad_size_divisor) * self.pad_size_divisor
                target_w = math.ceil(
                    w / self.pad_size_divisor) * self.pad_size_divisor
                pad_h = target_h - h
                pad_w = target_w - w
                batch_imgs = F.pad(_batch_imgs, (0, pad_w, 0, pad_h),
                                   'constant', self.pad_value)
            else:
                raise TypeError(
                    'Output of `cast_data` should be a list of dict '
                    'or a tuple with inputs and data_samples, but got '
                    f'{type(data)}: {data}')

            data['inputs']['imgs'] = batch_imgs

        data.setdefault('data_samples', None)

        return data

    def _get_pad_shape(self, data: dict) -> List[Tuple[int, int]]:
        """Get the pad_shape of each image based on data and
        pad_size_divisor."""
        # rewrite `_get_pad_shape` for obtaining image inputs.
        _batch_inputs = data['inputs']['img']
        # Process data with `pseudo_collate`.
        if is_seq_of(_batch_inputs, torch.Tensor):
            batch_pad_shape = []
            for ori_input in _batch_inputs:
                if ori_input.dim() == 4:
                    # A 4-dim input means multiview images: select one of
                    # the images to calculate the pad shape.
                    ori_input = ori_input[0]
                pad_h = int(
                    np.ceil(ori_input.shape[1] /
                            self.pad_size_divisor)) * self.pad_size_divisor
                pad_w = int(
                    np.ceil(ori_input.shape[2] /
                            self.pad_size_divisor)) * self.pad_size_divisor
                batch_pad_shape.append((pad_h, pad_w))
        # Process data with `default_collate`.
        elif isinstance(_batch_inputs, torch.Tensor):
            assert _batch_inputs.dim() == 4, (
                'The input of `ImgDataPreprocessor` should be a NCHW tensor '
                'or a list of tensor, but got a tensor with shape: '
                f'{_batch_inputs.shape}')
            pad_h = int(
                np.ceil(_batch_inputs.shape[1] /
                        self.pad_size_divisor)) * self.pad_size_divisor
            pad_w = int(
                np.ceil(_batch_inputs.shape[2] /
                        self.pad_size_divisor)) * self.pad_size_divisor
            batch_pad_shape = [(pad_h, pad_w)] * _batch_inputs.shape[0]
        else:
            raise TypeError('Output of `cast_data` should be a list of dict '
                            'or a tuple with inputs and data_samples, but got '
                            f'{type(data)}: {data}')
        return batch_pad_shape

    @torch.no_grad()
    def voxelize(self, points: List[Tensor],
                 data_samples: SampleList) -> Dict[str, Tensor]:
        """Apply voxelization to point cloud.

        Args:
            points (List[Tensor]): Point cloud in one data batch.
            data_samples (list[:obj:`Det3DDataSample`]): The annotation data
                of each sample. Voxel-wise annotations are added for
                segmentation.

        Returns:
            Dict[str, Tensor]: Voxelization information.

            - voxels (Tensor): Features of voxels, shape is MxNxC for hard
              voxelization, NxC for dynamic voxelization.
            - coors (Tensor): Coordinates of voxels, shape is Nx(1+NDim),
              where 1 represents the batch index.
            - num_points (Tensor, optional): Number of points in each voxel.
            - voxel_centers (Tensor, optional): Centers of voxels.
        """

        voxel_dict = dict()

        if self.voxel_type == 'hard':
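            # Hard voxelization caps both the points per voxel and the total
            # number of voxels, producing voxel features of shape (M, N, C).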
            voxels, coors, num_points, voxel_centers = [], [], [], []
            for i, res in enumerate(points):
                res_voxels, res_coors, res_num_points = self.voxel_layer(res)
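                # `res_coors` is in (z, y, x) order: reorder to (x, y, z),
                # shift by half a voxel and scale to metric voxel centers.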
                res_voxel_centers = (
                    res_coors[:, [2, 1, 0]] + 0.5) * res_voxels.new_tensor(
                        self.voxel_layer.voxel_size) + res_voxels.new_tensor(
                            self.voxel_layer.point_cloud_range[0:3])
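                # Prepend the batch index so each coordinate row becomes
                # (batch_idx, z, y, x).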
                res_coors = F.pad(res_coors, (1, 0), mode='constant', value=i)
                voxels.append(res_voxels)
                coors.append(res_coors)
                num_points.append(res_num_points)
                voxel_centers.append(res_voxel_centers)

            voxels = torch.cat(voxels, dim=0)
            coors = torch.cat(coors, dim=0)
            num_points = torch.cat(num_points, dim=0)
            voxel_centers = torch.cat(voxel_centers, dim=0)

            voxel_dict['num_points'] = num_points
            voxel_dict['voxel_centers'] = voxel_centers
        elif self.voxel_type == 'dynamic':
            coors = []
            # Dynamic voxelization only provides a point-to-voxel coordinate
            # mapping; point features are kept as-is.
            for i, res in enumerate(points):
                res_coors = self.voxel_layer(res)
                res_coors = F.pad(res_coors, (1, 0), mode='constant', value=i)
                coors.append(res_coors)
            voxels = torch.cat(points, dim=0)
            coors = torch.cat(coors, dim=0)
        elif self.voxel_type == 'cylindrical':
            voxels, coors = [], []
            for i, (res, data_sample) in enumerate(zip(points, data_samples)):
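                # Convert Cartesian (x, y) to polar (rho, phi); z is kept.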
                rho = torch.sqrt(res[:, 0]**2 + res[:, 1]**2)
                phi = torch.atan2(res[:, 1], res[:, 0])
                polar_res = torch.stack((rho, phi, res[:, 2]), dim=-1)
                min_bound = polar_res.new_tensor(
                    self.voxel_layer.point_cloud_range[:3])
                max_bound = polar_res.new_tensor(
                    self.voxel_layer.point_cloud_range[3:])
                try:  # torch.clamp with tensor bounds requires PyTorch >= 1.9.0
                    polar_res_clamp = torch.clamp(polar_res, min_bound,
                                                  max_bound)
                except TypeError:
                    polar_res_clamp = polar_res.clone()
                    for coor_idx in range(3):
                        polar_res_clamp[:, coor_idx][
                            polar_res[:, coor_idx] >
                            max_bound[coor_idx]] = max_bound[coor_idx]
                        polar_res_clamp[:, coor_idx][
                            polar_res[:, coor_idx] <
                            min_bound[coor_idx]] = min_bound[coor_idx]
                res_coors = torch.floor(
                    (polar_res_clamp - min_bound) / polar_res_clamp.new_tensor(
                        self.voxel_layer.voxel_size)).int()
                self.get_voxel_seg(res_coors, data_sample)
                res_coors = F.pad(res_coors, (1, 0), mode='constant', value=i)
                res_voxels = torch.cat((polar_res, res[:, :2], res[:, 3:]),
                                       dim=-1)
                voxels.append(res_voxels)
                coors.append(res_coors)
            voxels = torch.cat(voxels, dim=0)
            coors = torch.cat(coors, dim=0)
        elif self.voxel_type == 'minkunet':
            voxels, coors = [], []
            voxel_size = points[0].new_tensor(self.voxel_layer.voxel_size)
            for i, (res, data_sample) in enumerate(zip(points, data_samples)):
                res_coors = torch.round(res[:, :3] / voxel_size).int()
                res_coors -= res_coors.min(0)[0]
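                # Shift coordinates to be non-negative so they can be hashed
                # by `ravel_hash` during sparse quantization.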

                res_coors_numpy = res_coors.cpu().numpy()
                inds, point2voxel_map = self.sparse_quantize(
                    res_coors_numpy, return_index=True, return_inverse=True)
                point2voxel_map = torch.from_numpy(point2voxel_map).cuda()
                if self.training and self.max_voxels is not None:
                    if len(inds) > self.max_voxels:
                        inds = np.random.choice(
                            inds, self.max_voxels, replace=False)
                inds = torch.from_numpy(inds).cuda()
                if hasattr(data_sample.gt_pts_seg, 'pts_semantic_mask'):
                    data_sample.gt_pts_seg.voxel_semantic_mask \
                        = data_sample.gt_pts_seg.pts_semantic_mask[inds]
                res_voxel_coors = res_coors[inds]
                res_voxels = res[inds]
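                # Put the batch index in the first or last coordinate column,
                # depending on the convention the downstream backend expects.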
                if self.batch_first:
                    res_voxel_coors = F.pad(
                        res_voxel_coors, (1, 0), mode='constant', value=i)
                    data_sample.batch_idx = res_voxel_coors[:, 0]
                else:
                    res_voxel_coors = F.pad(
                        res_voxel_coors, (0, 1), mode='constant', value=i)
                    data_sample.batch_idx = res_voxel_coors[:, -1]
                data_sample.point2voxel_map = point2voxel_map.long()
                voxels.append(res_voxels)
                coors.append(res_voxel_coors)
            voxels = torch.cat(voxels, dim=0)
            coors = torch.cat(coors, dim=0)

        else:
            raise ValueError(f'Invalid voxelization type {self.voxel_type}')

        voxel_dict['voxels'] = voxels
        voxel_dict['coors'] = coors

        return voxel_dict

    def get_voxel_seg(self, res_coors: Tensor,
                      data_sample: SampleList) -> None:
        """Get voxel-wise segmentation label and point2voxel map.

        Args:
            res_coors (Tensor): The voxel coordinates of points, Nx3.
            data_sample (:obj:`Det3DDataSample`): The annotation data of
                each sample. Voxel-wise annotations are added for
                segmentation.
        """

        if self.training:
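            # One-hot encode point labels, scatter-average them per voxel and
            # take the argmax: a per-voxel majority vote over point labels.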
            pts_semantic_mask = data_sample.gt_pts_seg.pts_semantic_mask
            voxel_semantic_mask, _, point2voxel_map = dynamic_scatter_3d(
                F.one_hot(pts_semantic_mask.long()).float(), res_coors, 'mean',
                True)
            voxel_semantic_mask = torch.argmax(voxel_semantic_mask, dim=-1)
            data_sample.gt_pts_seg.voxel_semantic_mask = voxel_semantic_mask
            data_sample.point2voxel_map = point2voxel_map
        else:
            pseudo_tensor = res_coors.new_ones([res_coors.shape[0], 1]).float()
            _, _, point2voxel_map = dynamic_scatter_3d(pseudo_tensor,
                                                       res_coors, 'mean', True)
            data_sample.point2voxel_map = point2voxel_map

    def ravel_hash(self, x: np.ndarray) -> np.ndarray:
        """Get voxel coordinates hash for np.unique.

        Args:
            x (np.ndarray): The voxel coordinates of points, Nx3.

        Returns:
            np.ndarray: Voxel coordinates hash.
        """
        assert x.ndim == 2, x.shape

        x = x - np.min(x, axis=0)
        x = x.astype(np.uint64, copy=False)
        xmax = np.max(x, axis=0).astype(np.uint64) + 1
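        # Treat each coordinate row as digits of a mixed-radix number with
        # per-axis radix `xmax`, yielding a unique scalar per coordinate.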

        h = np.zeros(x.shape[0], dtype=np.uint64)
        for k in range(x.shape[1] - 1):
            h += x[:, k]
            h *= xmax[k + 1]
        h += x[:, -1]
        return h

    def sparse_quantize(self,
                        coords: np.ndarray,
                        return_index: bool = False,
                        return_inverse: bool = False) -> List[np.ndarray]:
        """Sparse Quantization for voxel coordinates used in Minkunet.

        Args:
            coords (np.ndarray): The voxel coordinates of points, Nx3.
            return_index (bool): Whether to return the indices of the unique
                coords, shape (M,).
            return_inverse (bool): Whether to return the inverse mapping from
                each original coord to its unique coord, shape (N,).

        Returns:
            List[np.ndarray]: The index array and/or the inverse map,
            included when ``return_index`` / ``return_inverse`` are True.
        """
        _, indices, inverse_indices = np.unique(
            self.ravel_hash(coords), return_index=True, return_inverse=True)
        coords = coords[indices]
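        # `indices` selects one representative point per unique voxel;
        # `inverse_indices` maps every original point to its voxel index.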

        outputs = []
        if return_index:
            outputs += [indices]
        if return_inverse:
            outputs += [inverse_indices]
        return outputs