"vscode:/vscode.git/clone" did not exist on "830fdd271536ee257db72c29c2be5b5629e58389"
data_preprocessor.py 22.8 KB
Newer Older
1
# Copyright (c) OpenMMLab. All rights reserved.
2
import math
3
from numbers import Number
4
from typing import Dict, List, Optional, Sequence, Union
5
6

import numpy as np
7
import torch
8
from mmdet.models import DetDataPreprocessor
9
from mmengine.model import stack_batch
10
from mmengine.utils import is_list_of
11
from torch.nn import functional as F
12
13

from mmdet3d.registry import MODELS
14
from mmdet3d.structures.det3d_data_sample import SampleList
15
from mmdet3d.utils import OptConfigType
16
from .utils import multiview_img_stack_batch
17
from .voxelize import VoxelizationByGridShape, dynamic_scatter_3d
18

19
20
21

@MODELS.register_module()
class Det3DDataPreprocessor(DetDataPreprocessor):
22
23
    """Points / Image pre-processor for point clouds / vision-only / multi-
    modality 3D detection tasks.
24
25
26

    It provides the data pre-processing as follows

27
28
29
    - Collate and move image and point cloud data to the target device.

    - 1) For image data:
30
31
    - Pad images in inputs to the maximum size of current batch with defined
      ``pad_value``. The padding size can be divisible by a defined
32
      ``pad_size_divisor``.
33
34
    - Stack images in inputs to batch_imgs.
    - Convert images in inputs from bgr to rgb if the shape of input is
35
      (3, H, W).
36
    - Normalize images in inputs with defined std and mean.
37
38
39
    - Do batch augmentations during training.

    - 2) For point cloud data:
40
41
    - If no voxelization, directly return list of point cloud data.
    - If voxelization is applied, voxelize point cloud according to
42
      ``voxel_type`` and obtain ``voxels``.
43
44

    Args:
45
46
        voxel (bool): Whether to apply voxelization to point cloud.
            Defaults to False.
47
48
49
        voxel_type (str): Voxelization type. Two voxelization types are
            provided: 'hard' and 'dynamic', respectively for hard
            voxelization and dynamic voxelization. Defaults to 'hard'.
50
        voxel_layer (dict or :obj:`ConfigDict`, optional): Voxelization layer
51
            config. Defaults to None.
52
53
        max_voxels (int, optional): Maximum number of voxels in each voxel
            grid. Defaults to None.
54
55
56
57
58
59
60
        mean (Sequence[Number], optional): The pixel mean of R, G, B channels.
            Defaults to None.
        std (Sequence[Number], optional): The pixel standard deviation of
            R, G, B channels. Defaults to None.
        pad_size_divisor (int): The size of padded image should be
            divisible by ``pad_size_divisor``. Defaults to 1.
        pad_value (Number): The padded pixel value. Defaults to 0.
61
62
63
64
65
66
67
68
        pad_mask (bool): Whether to pad instance masks. Defaults to False.
        mask_pad_value (int): The padded pixel value for instance masks.
            Defaults to 0.
        pad_seg (bool): Whether to pad semantic segmentation maps.
            Defaults to False.
        seg_pad_value (int): The padded pixel value for semantic
            segmentation maps. Defaults to 255.
        bgr_to_rgb (bool): Whether to convert image from BGR to RGB.
69
            Defaults to False.
70
        rgb_to_bgr (bool): Whether to convert image from RGB to BGR.
71
            Defaults to False.
72
73
74
75
        boxtype2tensor (bool): Whether to keep the ``BaseBoxes`` type of
            bboxes data or not. Defaults to True.
        batch_augments (List[dict], optional): Batch-level augmentations.
            Defaults to None.
76
77
78
    """

    def __init__(self,
                 voxel: bool = False,
                 voxel_type: str = 'hard',
                 voxel_layer: OptConfigType = None,
                 max_voxels: Optional[int] = None,
                 mean: Sequence[Number] = None,
                 std: Sequence[Number] = None,
                 pad_size_divisor: int = 1,
                 pad_value: Union[float, int] = 0,
                 pad_mask: bool = False,
                 mask_pad_value: int = 0,
                 pad_seg: bool = False,
                 seg_pad_value: int = 255,
                 bgr_to_rgb: bool = False,
                 rgb_to_bgr: bool = False,
                 boxtype2tensor: bool = True,
                 batch_augments: Optional[List[dict]] = None) -> None:
        """Initialize the 3D data pre-processor.

        See the class docstring for the meaning of each argument.

        Raises:
            TypeError: If ``voxel`` is True but no ``voxel_layer`` config is
                given (previously this surfaced as an opaque ``**None``
                TypeError from the voxelization layer constructor).
        """
        super(Det3DDataPreprocessor, self).__init__(
            mean=mean,
            std=std,
            pad_size_divisor=pad_size_divisor,
            pad_value=pad_value,
            pad_mask=pad_mask,
            mask_pad_value=mask_pad_value,
            pad_seg=pad_seg,
            seg_pad_value=seg_pad_value,
            bgr_to_rgb=bgr_to_rgb,
            rgb_to_bgr=rgb_to_bgr,
            batch_augments=batch_augments)
        # Fix: ``boxtype2tensor`` was accepted but silently dropped — it was
        # neither forwarded to the parent nor stored, so the user's setting
        # never reached the ``hasattr(self, 'boxtype2tensor')`` check in
        # ``simple_process``. Store it after ``super().__init__`` so an
        # explicit argument always wins over any parent-side default.
        self.boxtype2tensor = boxtype2tensor
        self.voxel = voxel
        self.voxel_type = voxel_type
        self.max_voxels = max_voxels
        if voxel:
            if voxel_layer is None:
                raise TypeError(
                    '`voxel_layer` config must be provided when `voxel=True`')
            self.voxel_layer = VoxelizationByGridShape(**voxel_layer)
    def forward(self,
                data: Union[dict, List[dict]],
                training: bool = False) -> Union[dict, List[dict]]:
        """Perform normalization, padding and bgr2rgb conversion based on
        ``BaseDataPreprocessor``.

        Args:
            data (dict or List[dict]): Data from dataloader. A dict contains
                the whole batch data; a list[dict] indicates test time
                augmentation.
            training (bool): Whether to enable training time augmentation.
                Defaults to False.

        Returns:
            dict or List[dict]: Data in the same format as the model input.
        """
        # Single-batch case: process directly.
        if not isinstance(data, list):
            return self.simple_process(data, training)
        # Test-time augmentation: each entry of the list is one augmented
        # view of the batch; process them independently.
        return [self.simple_process(aug_data, training) for aug_data in data]
    def simple_process(self, data: dict, training: bool = False) -> dict:
        """Perform normalization, padding and bgr2rgb conversion for img data
        based on ``BaseDataPreprocessor``, and voxelize point cloud if `voxel`
        is set to be True.

        Args:
            data (dict): Data sampled from dataloader. Expected to contain an
                ``'inputs'`` dict (with optional ``'img'`` and/or ``'points'``
                entries) and optionally ``'data_samples'``.
            training (bool): Whether to enable training time augmentation.
                Defaults to False.

        Returns:
            dict: Data in the same format as the model input, i.e.
            ``{'inputs': batch_inputs, 'data_samples': data_samples}``.
        """
        # Pad shapes must be computed BEFORE collate_data, which pads/stacks
        # the raw 'img' entry into a batched 'imgs' entry.
        if 'img' in data['inputs']:
            batch_pad_shape = self._get_pad_shape(data)

        data = self.collate_data(data)
        inputs, data_samples = data['inputs'], data['data_samples']
        batch_inputs = dict()

        if 'points' in inputs:
            # Point clouds are kept as a list of per-sample tensors.
            batch_inputs['points'] = inputs['points']

            if self.voxel:
                # Voxelization also writes voxel-wise annotations into
                # data_samples for segmentation-style voxel types.
                voxel_dict = self.voxelize(inputs['points'], data_samples)
                batch_inputs['voxels'] = voxel_dict

        if 'imgs' in inputs:
            imgs = inputs['imgs']

            if data_samples is not None:
                # NOTE the batched image size information may be useful, e.g.
                # in DETR, this is needed for the construction of masks, which
                # is then used for the transformer_head.
                batch_input_shape = tuple(imgs[0].size()[-2:])
                for data_sample, pad_shape in zip(data_samples,
                                                  batch_pad_shape):
                    data_sample.set_metainfo({
                        'batch_input_shape': batch_input_shape,
                        'pad_shape': pad_shape
                    })

                # hasattr guards keep compatibility across mmdet versions,
                # which renamed boxlist2tensor -> boxtype2tensor.
                if hasattr(self, 'boxtype2tensor') and self.boxtype2tensor:
                    from mmdet.models.utils.misc import \
                        samplelist_boxtype2tensor
                    samplelist_boxtype2tensor(data_samples)
                elif hasattr(self, 'boxlist2tensor') and self.boxlist2tensor:
                    from mmdet.models.utils.misc import \
                        samplelist_boxlist2tensor
                    samplelist_boxlist2tensor(data_samples)
                if self.pad_mask:
                    self.pad_gt_masks(data_samples)

                if self.pad_seg:
                    self.pad_gt_sem_seg(data_samples)

            # Batch-level augmentations run only at training time and may
            # modify both the images and the data samples.
            if training and self.batch_augments is not None:
                for batch_aug in self.batch_augments:
                    imgs, data_samples = batch_aug(imgs, data_samples)
            batch_inputs['imgs'] = imgs

        return {'inputs': batch_inputs, 'data_samples': data_samples}
    def preprocess_img(self, _batch_img: torch.Tensor) -> torch.Tensor:
        """Apply channel conversion and normalization to a single image.

        Args:
            _batch_img (Tensor): A (C, H, W) image tensor.

        Returns:
            Tensor: The converted and normalized float image.
        """
        img = _batch_img
        # BGR <-> RGB swap, if configured on the base preprocessor.
        if self._channel_conversion:
            img = img[[2, 1, 0], ...]
        # Convert to float only after the (cheap, integer) channel swap.
        img = img.float()
        if self._enable_normalize:
            if self.mean.shape[0] == 3:
                # A 3-value mean implies per-channel normalization of a
                # (3, H, W) image.
                assert img.dim() == 3 and img.shape[0] == 3, (
                    'If the mean has 3 values, the input tensor '
                    'should in shape of (3, H, W), but got the '
                    f'tensor with shape {img.shape}')
            img = (img - self.mean) / self.std
        return img
    def collate_data(self, data: dict) -> dict:
        """Copy data to the target device, then normalize, pad, bgr2rgb
        convert and stack images based on ``BaseDataPreprocessor``.

        Collates the data sampled from dataloader into a list of dict and
        list of labels, and then copies tensor to the target device.

        Args:
            data (dict): Data sampled from dataloader.

        Returns:
            dict: Data in the same format as the model input; the batched
            images are stored under ``data['inputs']['imgs']``.
        """
        data = self.cast_data(data)  # type: ignore

        if 'img' in data['inputs']:
            _batch_imgs = data['inputs']['img']
            # Process data with `pseudo_collate`.
            # A list of tensors: each sample may have a different size, so
            # preprocess per sample and then pad-and-stack.
            if is_list_of(_batch_imgs, torch.Tensor):
                batch_imgs = []
                # dim 3 = single (C, H, W) image; dim 4 = multiview
                # (V, C, H, W) stack of images for one sample.
                img_dim = _batch_imgs[0].dim()
                for _batch_img in _batch_imgs:
                    if img_dim == 3:  # standard img
                        _batch_img = self.preprocess_img(_batch_img)
                    elif img_dim == 4:
                        _batch_img = [
                            self.preprocess_img(_img) for _img in _batch_img
                        ]

                        _batch_img = torch.stack(_batch_img, dim=0)

                    batch_imgs.append(_batch_img)

                # Pad and stack Tensor.
                if img_dim == 3:
                    batch_imgs = stack_batch(batch_imgs, self.pad_size_divisor,
                                             self.pad_value)
                elif img_dim == 4:
                    batch_imgs = multiview_img_stack_batch(
                        batch_imgs, self.pad_size_divisor, self.pad_value)

            # Process data with `default_collate`.
            # Already a batched NCHW tensor: preprocess and pad in place.
            elif isinstance(_batch_imgs, torch.Tensor):
                assert _batch_imgs.dim() == 4, (
                    'The input of `ImgDataPreprocessor` should be a NCHW '
                    'tensor or a list of tensor, but got a tensor with '
                    f'shape: {_batch_imgs.shape}')
                if self._channel_conversion:
                    _batch_imgs = _batch_imgs[:, [2, 1, 0], ...]
                # Convert to float after channel conversion to ensure
                # efficiency
                _batch_imgs = _batch_imgs.float()
                if self._enable_normalize:
                    _batch_imgs = (_batch_imgs - self.mean) / self.std
                # Pad H and W up to the next multiple of pad_size_divisor.
                h, w = _batch_imgs.shape[2:]
                target_h = math.ceil(
                    h / self.pad_size_divisor) * self.pad_size_divisor
                target_w = math.ceil(
                    w / self.pad_size_divisor) * self.pad_size_divisor
                pad_h = target_h - h
                pad_w = target_w - w
                batch_imgs = F.pad(_batch_imgs, (0, pad_w, 0, pad_h),
                                   'constant', self.pad_value)
            else:
                raise TypeError(
                    'Output of `cast_data` should be a list of dict '
                    'or a tuple with inputs and data_samples, but got'
                    f'{type(data)}: {data}')

            # The batched result is stored under a new key; the raw per-sample
            # 'img' entry is left untouched.
            data['inputs']['imgs'] = batch_imgs

        data.setdefault('data_samples', None)

        return data
    def _get_pad_shape(self, data: dict) -> List[tuple]:
        """Get the pad_shape of each image based on data and
        pad_size_divisor.

        Args:
            data (dict): Data sampled from dataloader, with the raw images
                under ``data['inputs']['img']``.

        Returns:
            List[tuple]: Per-sample ``(pad_h, pad_w)`` — each spatial dim
            rounded up to a multiple of ``pad_size_divisor``.
        """
        # rewrite `_get_pad_shape` for obtaining image inputs.
        _batch_inputs = data['inputs']['img']
        # Process data with `pseudo_collate`.
        if is_list_of(_batch_inputs, torch.Tensor):
            batch_pad_shape = []
            for ori_input in _batch_inputs:
                if ori_input.dim() == 4:
                    # mean multiview input, select one of the
                    # image to calculate the pad shape
                    ori_input = ori_input[0]
                # ori_input is (C, H, W) here: shape[1]=H, shape[2]=W.
                pad_h = int(
                    np.ceil(ori_input.shape[1] /
                            self.pad_size_divisor)) * self.pad_size_divisor
                pad_w = int(
                    np.ceil(ori_input.shape[2] /
                            self.pad_size_divisor)) * self.pad_size_divisor
                batch_pad_shape.append((pad_h, pad_w))
        # Process data with `default_collate`.
        elif isinstance(_batch_inputs, torch.Tensor):
            assert _batch_inputs.dim() == 4, (
                'The input of `ImgDataPreprocessor` should be a NCHW tensor '
                'or a list of tensor, but got a tensor with shape: '
                f'{_batch_inputs.shape}')
            # Fix: for an NCHW tensor the spatial dims are shape[2] (H) and
            # shape[3] (W); the previous shape[1]/shape[2] read the channel
            # count as the height.
            pad_h = int(
                np.ceil(_batch_inputs.shape[2] /
                        self.pad_size_divisor)) * self.pad_size_divisor
            pad_w = int(
                np.ceil(_batch_inputs.shape[3] /
                        self.pad_size_divisor)) * self.pad_size_divisor
            batch_pad_shape = [(pad_h, pad_w)] * _batch_inputs.shape[0]
        else:
            raise TypeError('Output of `cast_data` should be a list of dict '
                            'or a tuple with inputs and data_samples, but got '
                            f'{type(data)}: {data}')
        return batch_pad_shape
    @torch.no_grad()
    def voxelize(self, points: List[torch.Tensor],
                 data_samples: SampleList) -> Dict[str, torch.Tensor]:
        """Apply voxelization to point cloud.

        Args:
            points (List[Tensor]): Point cloud in one data batch.
            data_samples: (list[:obj:`Det3DDataSample`]): The annotation data
                of every samples. Add voxel-wise annotation for segmentation.

        Returns:
            Dict[str, Tensor]: Voxelization information.

            - voxels (Tensor): Features of voxels, shape is MxNxC for hard
              voxelization, NxC for dynamic voxelization.
            - coors (Tensor): Coordinates of voxels, shape is Nx(1+NDim),
              where 1 represents the batch index.
            - num_points (Tensor, optional): Number of points in each voxel.
            - voxel_centers (Tensor, optional): Centers of voxels.
        """
        voxel_dict = dict()

        if self.voxel_type == 'hard':
            voxels, coors, num_points, voxel_centers = [], [], [], []
            for i, res in enumerate(points):
                res_voxels, res_coors, res_num_points = self.voxel_layer(res)
                # Voxel centers: coors are (z, y, x) order, hence the
                # [2, 1, 0] reindex to (x, y, z) before scaling by voxel
                # size and offsetting by the range minimum.
                res_voxel_centers = (
                    res_coors[:, [2, 1, 0]] + 0.5) * res_voxels.new_tensor(
                        self.voxel_layer.voxel_size) + res_voxels.new_tensor(
                            self.voxel_layer.point_cloud_range[0:3])
                # Prepend the batch index as the first coordinate column.
                res_coors = F.pad(res_coors, (1, 0), mode='constant', value=i)
                voxels.append(res_voxels)
                coors.append(res_coors)
                num_points.append(res_num_points)
                voxel_centers.append(res_voxel_centers)

            voxels = torch.cat(voxels, dim=0)
            coors = torch.cat(coors, dim=0)
            num_points = torch.cat(num_points, dim=0)
            voxel_centers = torch.cat(voxel_centers, dim=0)

            voxel_dict['num_points'] = num_points
            voxel_dict['voxel_centers'] = voxel_centers
        elif self.voxel_type == 'dynamic':
            coors = []
            # dynamic voxelization only provide a coors mapping
            for i, res in enumerate(points):
                res_coors = self.voxel_layer(res)
                res_coors = F.pad(res_coors, (1, 0), mode='constant', value=i)
                coors.append(res_coors)
            # With dynamic voxelization the "voxels" are just the raw points.
            voxels = torch.cat(points, dim=0)
            coors = torch.cat(coors, dim=0)
        elif self.voxel_type == 'cylindrical':
            voxels, coors = [], []
            for i, (res, data_sample) in enumerate(zip(points, data_samples)):
                # Convert (x, y) to cylindrical (rho, phi); z is kept as-is.
                rho = torch.sqrt(res[:, 0]**2 + res[:, 1]**2)
                phi = torch.atan2(res[:, 1], res[:, 0])
                polar_res = torch.stack((rho, phi, res[:, 2]), dim=-1)
                min_bound = polar_res.new_tensor(
                    self.voxel_layer.point_cloud_range[:3])
                max_bound = polar_res.new_tensor(
                    self.voxel_layer.point_cloud_range[3:])
                try:  # only support PyTorch >= 1.9.0
                    polar_res_clamp = torch.clamp(polar_res, min_bound,
                                                  max_bound)
                except TypeError:
                    # Fallback for older PyTorch without tensor bounds in
                    # torch.clamp: clamp each coordinate column manually.
                    polar_res_clamp = polar_res.clone()
                    for coor_idx in range(3):
                        polar_res_clamp[:, coor_idx][
                            polar_res[:, coor_idx] >
                            max_bound[coor_idx]] = max_bound[coor_idx]
                        polar_res_clamp[:, coor_idx][
                            polar_res[:, coor_idx] <
                            min_bound[coor_idx]] = min_bound[coor_idx]
                res_coors = torch.floor(
                    (polar_res_clamp - min_bound) / polar_res_clamp.new_tensor(
                        self.voxel_layer.voxel_size)).int()
                # Writes voxel_semantic_mask / point2voxel_map annotations
                # into data_sample; must run before the batch index is added.
                self.get_voxel_seg(res_coors, data_sample)
                res_coors = F.pad(res_coors, (1, 0), mode='constant', value=i)
                # Voxel features: polar coords + original (x, y) + remaining
                # point features (intensity etc.).
                res_voxels = torch.cat((polar_res, res[:, :2], res[:, 3:]),
                                       dim=-1)
                voxels.append(res_voxels)
                coors.append(res_coors)
            voxels = torch.cat(voxels, dim=0)
            coors = torch.cat(coors, dim=0)
        elif self.voxel_type == 'minkunet':
            voxels, coors = [], []
            voxel_size = points[0].new_tensor(self.voxel_layer.voxel_size)
            for i, (res, data_sample) in enumerate(zip(points, data_samples)):
                # Quantize to integer grid coords, shifted to be non-negative.
                res_coors = torch.round(res[:, :3] / voxel_size).int()
                res_coors -= res_coors.min(0)[0]

                res_coors_numpy = res_coors.cpu().numpy()
                inds, point2voxel_map = self.sparse_quantize(
                    res_coors_numpy, return_index=True, return_inverse=True)
                # NOTE(review): hard-coded .cuda() assumes a CUDA device is
                # available — confirm whether this should follow self.device.
                point2voxel_map = torch.from_numpy(point2voxel_map).cuda()
                if self.training and self.max_voxels is not None:
                    if len(inds) > self.max_voxels:
                        # Randomly subsample voxels to cap memory at train
                        # time.
                        # NOTE(review): point2voxel_map is not recomputed
                        # after this subsampling, so it still refers to the
                        # full voxel set — verify downstream usage.
                        inds = np.random.choice(
                            inds, self.max_voxels, replace=False)
                inds = torch.from_numpy(inds).cuda()
                if hasattr(data_sample.gt_pts_seg, 'pts_semantic_mask'):
                    data_sample.gt_pts_seg.voxel_semantic_mask \
                        = data_sample.gt_pts_seg.pts_semantic_mask[inds]
                res_voxel_coors = res_coors[inds]
                res_voxels = res[inds]
                # NOTE: the batch index is APPENDED here (pad (0, 1)), unlike
                # the other voxel types which prepend it (pad (1, 0)).
                res_voxel_coors = F.pad(
                    res_voxel_coors, (0, 1), mode='constant', value=i)
                data_sample.point2voxel_map = point2voxel_map.long()
                voxels.append(res_voxels)
                coors.append(res_voxel_coors)
            voxels = torch.cat(voxels, dim=0)
            coors = torch.cat(coors, dim=0)
        else:
            raise ValueError(f'Invalid voxelization type {self.voxel_type}')

        voxel_dict['voxels'] = voxels
        voxel_dict['coors'] = coors

        return voxel_dict
    def get_voxel_seg(self, res_coors: torch.Tensor, data_sample: SampleList):
        """Attach voxel-wise segmentation label and point2voxel map.

        Args:
            res_coors (Tensor): The voxel coordinates of points, Nx3.
            data_sample: (:obj:`Det3DDataSample`): The annotation data of
                every samples. Add voxel-wise annotation for segmentation.
        """
        if not self.training:
            # At inference only the point-to-voxel mapping is needed, so
            # scatter a dummy all-ones feature to obtain it.
            dummy_feats = res_coors.new_ones([res_coors.shape[0], 1]).float()
            _, _, point2voxel_map = dynamic_scatter_3d(dummy_feats, res_coors,
                                                       'mean', True)
            data_sample.point2voxel_map = point2voxel_map
            return

        # Training: average the one-hot point labels within each voxel and
        # take the argmax as the voxel's (majority) semantic label.
        point_labels = data_sample.gt_pts_seg.pts_semantic_mask
        one_hot_labels = F.one_hot(point_labels.long()).float()
        voxel_label_probs, _, point2voxel_map = dynamic_scatter_3d(
            one_hot_labels, res_coors, 'mean', True)
        data_sample.gt_pts_seg.voxel_semantic_mask = torch.argmax(
            voxel_label_probs, dim=-1)
        data_sample.point2voxel_map = point2voxel_map
    def ravel_hash(self, x: np.ndarray) -> np.ndarray:
        """Get voxel coordinates hash for np.unique().

        Args:
            x (np.ndarray): The voxel coordinates of points, Nx3.

        Returns:
            np.ndarray: Voxels coordinates hash.
        """
        assert x.ndim == 2, x.shape

        # Shift coordinates to be non-negative, then work in uint64 so the
        # mixing arithmetic wraps deterministically.
        x = (x - np.min(x, axis=0)).astype(np.uint64, copy=False)
        # Per-axis extent (+1 so the maximum value is still a valid digit).
        xmax = np.max(x, axis=0).astype(np.uint64) + 1

        # Horner's scheme: treat each row as a mixed-radix number with
        # radices xmax[1:], yielding one scalar key per row.
        h = x[:, 0].copy()
        for k in range(1, x.shape[1]):
            h *= xmax[k]
            h += x[:, k]
        return h
    def sparse_quantize(self,
                        coords: np.ndarray,
                        return_index: bool = False,
                        return_inverse: bool = False) -> List[np.ndarray]:
        """Sparse Quantization for voxel coordinates used in Minkunet.

        Args:
            coords (np.ndarray): The voxel coordinates of points, Nx3.
            return_index (bool): Whether to return the indices of the
                unique coords, shape (M,).
            return_inverse (bool): Whether to return the indices of the
                original coords, shape (N,).

        Returns:
            List[np.ndarray] or None: Return index and inverse map if
            return_index and return_inverse is True.
        """
        # Hash each coordinate row to a single scalar so np.unique can
        # deduplicate whole rows in one pass.
        _, unique_idx, inverse_map = np.unique(
            self.ravel_hash(coords), return_index=True, return_inverse=True)

        result = []
        if return_index:
            result.append(unique_idx)
        if return_inverse:
            result.append(inverse_map)
        return result