# Copyright (c) OpenMMLab. All rights reserved.
import math
from numbers import Number
from typing import Dict, List, Optional, Sequence, Union

import numpy as np
import torch
from mmdet.models import DetDataPreprocessor
from mmengine.model import stack_batch
from mmengine.utils import is_list_of
from torch.nn import functional as F

from mmdet3d.registry import MODELS
from mmdet3d.structures.det3d_data_sample import SampleList
from mmdet3d.utils import OptConfigType
from .utils import multiview_img_stack_batch
from .voxelize import VoxelizationByGridShape, dynamic_scatter_3d


@MODELS.register_module()
class Det3DDataPreprocessor(DetDataPreprocessor):
    """Points / Image pre-processor for point clouds / vision-only / multi-
    modality 3D detection tasks.

    It provides the data pre-processing as follows:

    - Collate and move image and point cloud data to the target device.

    - 1) For image data:

      - Pad images in inputs to the maximum size of current batch with defined
        ``pad_value``. The padding size can be divisible by a defined
        ``pad_size_divisor``.
      - Stack images in inputs to batch_imgs.
      - Convert images in inputs from BGR to RGB if the shape of input is
        (3, H, W).
      - Normalize images in inputs with defined std and mean.
      - Do batch augmentations during training.

    - 2) For point cloud data:

      - If no voxelization, directly return list of point cloud data.
      - If voxelization is applied, voxelize point cloud according to
        ``voxel_type`` and obtain ``voxels``.

    Args:
        voxel (bool): Whether to apply voxelization to point cloud.
            Defaults to False.
        voxel_type (str): Voxelization type. Four voxelization types are
            provided: 'hard', 'dynamic', 'cylindrical' and 'minkunet'.
            Defaults to 'hard'.
        voxel_layer (dict or :obj:`ConfigDict`, optional): Voxelization layer
            config. Defaults to None.
        batch_first (bool): Whether to put the batch dimension to the first
            dimension when getting voxel coordinates. Defaults to True.
        max_voxels (int, optional): Maximum number of voxels in each voxel
            grid. Defaults to None.
        mean (Sequence[Number], optional): The pixel mean of R, G, B channels.
            Defaults to None.
        std (Sequence[Number], optional): The pixel standard deviation of
            R, G, B channels. Defaults to None.
        pad_size_divisor (int): The size of padded image should be
            divisible by ``pad_size_divisor``. Defaults to 1.
        pad_value (Number): The padded pixel value. Defaults to 0.
        pad_mask (bool): Whether to pad instance masks. Defaults to False.
        mask_pad_value (int): The padded pixel value for instance masks.
            Defaults to 0.
        pad_seg (bool): Whether to pad semantic segmentation maps.
            Defaults to False.
        seg_pad_value (int): The padded pixel value for semantic
            segmentation maps. Defaults to 255.
        bgr_to_rgb (bool): Whether to convert image from BGR to RGB.
            Defaults to False.
        rgb_to_bgr (bool): Whether to convert image from RGB to BGR.
            Defaults to False.
        boxtype2tensor (bool): Whether to convert the ``BaseBoxes`` type of
            bboxes data to ``Tensor`` type. Defaults to True.
        batch_augments (List[dict], optional): Batch-level augmentations.
            Defaults to None.
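
    Example:
        >>> # A minimal sketch, assuming a typical hard-voxelization setup;
        >>> # the `voxel_layer` values below are illustrative, not taken
        >>> # from any particular config.
        >>> preprocessor = Det3DDataPreprocessor(
        ...     voxel=True,
        ...     voxel_type='hard',
        ...     voxel_layer=dict(
        ...         max_num_points=32,
        ...         point_cloud_range=[0, -40, -3, 70.4, 40, 1],
        ...         voxel_size=[0.05, 0.05, 0.1],
        ...         max_voxels=(16000, 40000)))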
    """

    def __init__(self,
                 voxel: bool = False,
                 voxel_type: str = 'hard',
                 voxel_layer: OptConfigType = None,
                 batch_first: bool = True,
                 max_voxels: Optional[int] = None,
                 mean: Sequence[Number] = None,
                 std: Sequence[Number] = None,
                 pad_size_divisor: int = 1,
                 pad_value: Union[float, int] = 0,
                 pad_mask: bool = False,
                 mask_pad_value: int = 0,
                 pad_seg: bool = False,
                 seg_pad_value: int = 255,
                 bgr_to_rgb: bool = False,
                 rgb_to_bgr: bool = False,
                 boxtype2tensor: bool = True,
                 batch_augments: Optional[List[dict]] = None) -> None:
        super(Det3DDataPreprocessor, self).__init__(
            mean=mean,
            std=std,
            pad_size_divisor=pad_size_divisor,
            pad_value=pad_value,
            pad_mask=pad_mask,
            mask_pad_value=mask_pad_value,
            pad_seg=pad_seg,
            seg_pad_value=seg_pad_value,
            bgr_to_rgb=bgr_to_rgb,
            rgb_to_bgr=rgb_to_bgr,
            batch_augments=batch_augments)
        self.voxel = voxel
        self.voxel_type = voxel_type
        self.batch_first = batch_first
        self.max_voxels = max_voxels
        if voxel:
            self.voxel_layer = VoxelizationByGridShape(**voxel_layer)

    def forward(self,
                data: Union[dict, List[dict]],
                training: bool = False) -> Union[dict, List[dict]]:
        """Perform normalization, padding and bgr2rgb conversion based on
121
122
123
        ``BaseDataPreprocessor``.

        Args:
124
            data (dict or List[dict]): Data from dataloader.
125
126
                The dict contains the whole batch data, when it is
                a list[dict], the list indicate test time augmentation.
127
            training (bool): Whether to enable training time augmentation.
128
                Defaults to False.
129
130

        Returns:
131
            dict or List[dict]: Data in the same format as the model input.
132
        """
        if isinstance(data, list):
            num_augs = len(data)
            aug_batch_data = []
            for aug_id in range(num_augs):
                single_aug_batch_data = self.simple_process(
                    data[aug_id], training)
                aug_batch_data.append(single_aug_batch_data)
            return aug_batch_data
        else:
            return self.simple_process(data, training)
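
    # Hedged usage sketch: ``forward`` accepts either a single collated batch
    # dict or a list of such dicts (test-time augmentation); the tensor below
    # is an illustrative placeholder, not real data.
    #
    #   batch = dict(inputs=dict(points=[torch.rand(1000, 4)]),
    #                data_samples=None)
    #   out = preprocessor(batch, training=False)            # dict in, dict out
    #   outs = preprocessor([batch, batch], training=False)  # list in, list out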

    def simple_process(self, data: dict, training: bool = False) -> dict:
        """Perform normalization, padding and bgr2rgb conversion for img data
        based on ``BaseDataPreprocessor``, and voxelize point cloud if
        ``voxel`` is set to be True.

        Args:
            data (dict): Data sampled from dataloader.
            training (bool): Whether to enable training time augmentation.
                Defaults to False.

        Returns:
            dict: Data in the same format as the model input.
        """
        if 'img' in data['inputs']:
            batch_pad_shape = self._get_pad_shape(data)

        data = self.collate_data(data)
        inputs, data_samples = data['inputs'], data['data_samples']
        batch_inputs = dict()

        if 'points' in inputs:
            batch_inputs['points'] = inputs['points']

            if self.voxel:
                voxel_dict = self.voxelize(inputs['points'], data_samples)
                batch_inputs['voxels'] = voxel_dict

        if 'imgs' in inputs:
            imgs = inputs['imgs']

            if data_samples is not None:
                # NOTE the batched image size information may be useful, e.g.
                # in DETR, this is needed for the construction of masks, which
                # is then used for the transformer_head.
                batch_input_shape = tuple(imgs[0].size()[-2:])
                for data_sample, pad_shape in zip(data_samples,
                                                  batch_pad_shape):
                    data_sample.set_metainfo({
                        'batch_input_shape': batch_input_shape,
                        'pad_shape': pad_shape
                    })

                if hasattr(self, 'boxtype2tensor') and self.boxtype2tensor:
                    from mmdet.models.utils.misc import \
                        samplelist_boxtype2tensor
                    samplelist_boxtype2tensor(data_samples)
                elif hasattr(self, 'boxlist2tensor') and self.boxlist2tensor:
                    from mmdet.models.utils.misc import \
                        samplelist_boxlist2tensor
                    samplelist_boxlist2tensor(data_samples)
                if self.pad_mask:
                    self.pad_gt_masks(data_samples)

                if self.pad_seg:
                    self.pad_gt_sem_seg(data_samples)

            if training and self.batch_augments is not None:
                for batch_aug in self.batch_augments:
                    imgs, data_samples = batch_aug(imgs, data_samples)
            batch_inputs['imgs'] = imgs

        return {'inputs': batch_inputs, 'data_samples': data_samples}
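
    # Hedged sketch of the returned structure (keys depend on the inputs and
    # on ``self.voxel``):
    #
    #   {'inputs': {'points': List[Tensor],       # if point clouds are given
    #               'voxels': Dict[str, Tensor],  # if self.voxel is True
    #               'imgs': Tensor},              # if images are given
    #    'data_samples': SampleList}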

    def preprocess_img(self, _batch_img: torch.Tensor) -> torch.Tensor:
        """Perform channel conversion and normalization on a single image."""
        # channel transform
        if self._channel_conversion:
            _batch_img = _batch_img[[2, 1, 0], ...]
        # Convert to float after channel conversion to ensure
        # efficiency
        _batch_img = _batch_img.float()
        # Normalization.
        if self._enable_normalize:
            if self.mean.shape[0] == 3:
                assert _batch_img.dim() == 3 and _batch_img.shape[0] == 3, (
                    'If the mean has 3 values, the input tensor should be '
                    'in shape of (3, H, W), but got a tensor with shape '
                    f'{_batch_img.shape}')
            _batch_img = (_batch_img - self.mean) / self.std
        return _batch_img
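
    # Illustrative note: with ``bgr_to_rgb=True`` and per-channel mean/std, a
    # (3, H, W) BGR image is reordered to RGB via ``img[[2, 1, 0], ...]``,
    # cast to float, then normalized as ``(img - mean) / std``. Typical
    # (hypothetical here) values: mean=[123.675, 116.28, 103.53],
    # std=[58.395, 57.12, 57.375].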

    def collate_data(self, data: dict) -> dict:
        """Copy data to the target device and perform normalization, padding,
        bgr2rgb conversion and stacking based on ``BaseDataPreprocessor``.

        Collates the data sampled from dataloader into a list of dict and
        list of labels, and then copies tensor to the target device.

        Args:
            data (dict): Data sampled from dataloader.

        Returns:
            dict: Data in the same format as the model input.
        """
        data = self.cast_data(data)  # type: ignore

        if 'img' in data['inputs']:
            _batch_imgs = data['inputs']['img']
            # Process data with `pseudo_collate`.
            if is_list_of(_batch_imgs, torch.Tensor):
                batch_imgs = []
                img_dim = _batch_imgs[0].dim()
                for _batch_img in _batch_imgs:
                    if img_dim == 3:  # standard img
                        _batch_img = self.preprocess_img(_batch_img)
                    elif img_dim == 4:
                        _batch_img = [
                            self.preprocess_img(_img) for _img in _batch_img
                        ]
                        _batch_img = torch.stack(_batch_img, dim=0)
                    batch_imgs.append(_batch_img)

                # Pad and stack Tensor.
                if img_dim == 3:
                    batch_imgs = stack_batch(batch_imgs, self.pad_size_divisor,
                                             self.pad_value)
                elif img_dim == 4:
                    batch_imgs = multiview_img_stack_batch(
                        batch_imgs, self.pad_size_divisor, self.pad_value)
            # Process data with `default_collate`.
            elif isinstance(_batch_imgs, torch.Tensor):
                assert _batch_imgs.dim() == 4, (
                    'The input of `ImgDataPreprocessor` should be a NCHW '
                    'tensor or a list of tensor, but got a tensor with '
                    f'shape: {_batch_imgs.shape}')
                if self._channel_conversion:
                    _batch_imgs = _batch_imgs[:, [2, 1, 0], ...]
                # Convert to float after channel conversion to ensure
                # efficiency
                _batch_imgs = _batch_imgs.float()
                if self._enable_normalize:
                    _batch_imgs = (_batch_imgs - self.mean) / self.std
                h, w = _batch_imgs.shape[2:]
                target_h = math.ceil(
                    h / self.pad_size_divisor) * self.pad_size_divisor
                target_w = math.ceil(
                    w / self.pad_size_divisor) * self.pad_size_divisor
                pad_h = target_h - h
                pad_w = target_w - w
                batch_imgs = F.pad(_batch_imgs, (0, pad_w, 0, pad_h),
                                   'constant', self.pad_value)
            else:
                raise TypeError(
                    'Output of `cast_data` should be a list of dict '
                    'or a tuple with inputs and data_samples, but got '
                    f'{type(data)}: {data}')

            data['inputs']['imgs'] = batch_imgs

        data.setdefault('data_samples', None)

        return data
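
    # Hedged summary of the two collate paths above:
    #
    # - `pseudo_collate`: ``data['inputs']['img']`` is a list of (3, H, W) or
    #   (N, 3, H, W) tensors; each is preprocessed, then padded and stacked
    #   into a single batched tensor.
    # - `default_collate`: ``data['inputs']['img']`` is already a batched NCHW
    #   tensor; it is channel-converted, normalized and padded to a multiple
    #   of ``pad_size_divisor``.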

    def _get_pad_shape(self, data: dict) -> List[tuple]:
        """Get the pad_shape of each image based on data and
        pad_size_divisor."""
        # rewrite `_get_pad_shape` for obtaining image inputs.
        _batch_inputs = data['inputs']['img']
        # Process data with `pseudo_collate`.
        if is_list_of(_batch_inputs, torch.Tensor):
            batch_pad_shape = []
            for ori_input in _batch_inputs:
                if ori_input.dim() == 4:
                    # means multi-view input, select one of the
                    # images to calculate the pad shape
                    ori_input = ori_input[0]
                pad_h = int(
                    np.ceil(ori_input.shape[1] /
                            self.pad_size_divisor)) * self.pad_size_divisor
                pad_w = int(
                    np.ceil(ori_input.shape[2] /
                            self.pad_size_divisor)) * self.pad_size_divisor
                batch_pad_shape.append((pad_h, pad_w))
        # Process data with `default_collate`.
        elif isinstance(_batch_inputs, torch.Tensor):
            assert _batch_inputs.dim() == 4, (
                'The input of `ImgDataPreprocessor` should be a NCHW tensor '
                'or a list of tensor, but got a tensor with shape: '
                f'{_batch_inputs.shape}')
            # For an NCHW tensor, H and W are dims 2 and 3.
            pad_h = int(
                np.ceil(_batch_inputs.shape[2] /
                        self.pad_size_divisor)) * self.pad_size_divisor
            pad_w = int(
                np.ceil(_batch_inputs.shape[3] /
                        self.pad_size_divisor)) * self.pad_size_divisor
            batch_pad_shape = [(pad_h, pad_w)] * _batch_inputs.shape[0]
        else:
            raise TypeError('Output of `cast_data` should be a list of dict '
                            'or a tuple with inputs and data_samples, but got '
                            f'{type(data)}: {data}')
        return batch_pad_shape
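
    # Worked example (hypothetical numbers): with ``pad_size_divisor=32`` and
    # an input image of H=375, W=500, the padded shape is
    # (ceil(375 / 32) * 32, ceil(500 / 32) * 32) = (384, 512).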

    @torch.no_grad()
    def voxelize(self, points: List[torch.Tensor],
                 data_samples: SampleList) -> Dict[str, torch.Tensor]:
        """Apply voxelization to point cloud.

        Args:
            points (List[Tensor]): Point cloud in one data batch.
            data_samples (list[:obj:`Det3DDataSample`]): The annotation data
                of every sample. Add voxel-wise annotation for segmentation.

        Returns:
            Dict[str, Tensor]: Voxelization information.

            - voxels (Tensor): Features of voxels, shape is MxNxC for hard
              voxelization, NxC for dynamic voxelization.
            - coors (Tensor): Coordinates of voxels, shape is Nx(1+NDim),
              where 1 represents the batch index.
            - num_points (Tensor, optional): Number of points in each voxel.
            - voxel_centers (Tensor, optional): Centers of voxels.
        """

        voxel_dict = dict()

        if self.voxel_type == 'hard':
            voxels, coors, num_points, voxel_centers = [], [], [], []
            for i, res in enumerate(points):
                res_voxels, res_coors, res_num_points = self.voxel_layer(res)
                res_voxel_centers = (
                    res_coors[:, [2, 1, 0]] + 0.5) * res_voxels.new_tensor(
                        self.voxel_layer.voxel_size) + res_voxels.new_tensor(
                            self.voxel_layer.point_cloud_range[0:3])
                res_coors = F.pad(res_coors, (1, 0), mode='constant', value=i)
                voxels.append(res_voxels)
                coors.append(res_coors)
                num_points.append(res_num_points)
                voxel_centers.append(res_voxel_centers)

            voxels = torch.cat(voxels, dim=0)
            coors = torch.cat(coors, dim=0)
            num_points = torch.cat(num_points, dim=0)
            voxel_centers = torch.cat(voxel_centers, dim=0)

            voxel_dict['num_points'] = num_points
            voxel_dict['voxel_centers'] = voxel_centers
        elif self.voxel_type == 'dynamic':
            coors = []
            # dynamic voxelization only provides a coors mapping
            for i, res in enumerate(points):
                res_coors = self.voxel_layer(res)
                res_coors = F.pad(res_coors, (1, 0), mode='constant', value=i)
                coors.append(res_coors)
            voxels = torch.cat(points, dim=0)
            coors = torch.cat(coors, dim=0)
        elif self.voxel_type == 'cylindrical':
            voxels, coors = [], []
            for i, (res, data_sample) in enumerate(zip(points, data_samples)):
                rho = torch.sqrt(res[:, 0]**2 + res[:, 1]**2)
                phi = torch.atan2(res[:, 1], res[:, 0])
                polar_res = torch.stack((rho, phi, res[:, 2]), dim=-1)
                min_bound = polar_res.new_tensor(
                    self.voxel_layer.point_cloud_range[:3])
                max_bound = polar_res.new_tensor(
                    self.voxel_layer.point_cloud_range[3:])
                try:  # tensor min/max bounds require PyTorch >= 1.9.0
                    polar_res_clamp = torch.clamp(polar_res, min_bound,
                                                  max_bound)
                except TypeError:
                    polar_res_clamp = polar_res.clone()
                    for coor_idx in range(3):
                        polar_res_clamp[:, coor_idx][
                            polar_res[:, coor_idx] >
                            max_bound[coor_idx]] = max_bound[coor_idx]
                        polar_res_clamp[:, coor_idx][
                            polar_res[:, coor_idx] <
                            min_bound[coor_idx]] = min_bound[coor_idx]
                res_coors = torch.floor(
                    (polar_res_clamp - min_bound) / polar_res_clamp.new_tensor(
                        self.voxel_layer.voxel_size)).int()
                self.get_voxel_seg(res_coors, data_sample)
                res_coors = F.pad(res_coors, (1, 0), mode='constant', value=i)
                res_voxels = torch.cat((polar_res, res[:, :2], res[:, 3:]),
                                       dim=-1)
                voxels.append(res_voxels)
                coors.append(res_coors)
            voxels = torch.cat(voxels, dim=0)
            coors = torch.cat(coors, dim=0)
        elif self.voxel_type == 'minkunet':
            voxels, coors = [], []
            voxel_size = points[0].new_tensor(self.voxel_layer.voxel_size)
            for i, (res, data_sample) in enumerate(zip(points, data_samples)):
                res_coors = torch.round(res[:, :3] / voxel_size).int()
                res_coors -= res_coors.min(0)[0]

                res_coors_numpy = res_coors.cpu().numpy()
                inds, point2voxel_map = self.sparse_quantize(
                    res_coors_numpy, return_index=True, return_inverse=True)
                point2voxel_map = torch.from_numpy(point2voxel_map).cuda()
                if self.training and self.max_voxels is not None:
                    if len(inds) > self.max_voxels:
                        inds = np.random.choice(
                            inds, self.max_voxels, replace=False)
                inds = torch.from_numpy(inds).cuda()
                if hasattr(data_sample.gt_pts_seg, 'pts_semantic_mask'):
                    data_sample.gt_pts_seg.voxel_semantic_mask \
                        = data_sample.gt_pts_seg.pts_semantic_mask[inds]
                res_voxel_coors = res_coors[inds]
                res_voxels = res[inds]
                if self.batch_first:
                    res_voxel_coors = F.pad(
                        res_voxel_coors, (1, 0), mode='constant', value=i)
                    data_sample.batch_idx = res_voxel_coors[:, 0]
                else:
                    res_voxel_coors = F.pad(
                        res_voxel_coors, (0, 1), mode='constant', value=i)
                    data_sample.batch_idx = res_voxel_coors[:, -1]
                data_sample.point2voxel_map = point2voxel_map.long()
                voxels.append(res_voxels)
                coors.append(res_voxel_coors)
            voxels = torch.cat(voxels, dim=0)
            coors = torch.cat(coors, dim=0)
        else:
            raise ValueError(f'Invalid voxelization type {self.voxel_type}')

        voxel_dict['voxels'] = voxels
        voxel_dict['coors'] = coors

        return voxel_dict
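
    # Hedged example of hard-voxelization outputs, assuming 4-dim points and
    # ``max_num_points=32`` per voxel (M is the total number of voxels):
    #
    #   voxel_dict['voxels']         # (M, 32, 4) padded point features
    #   voxel_dict['coors']          # (M, 4) as (batch_idx, z, y, x)
    #   voxel_dict['num_points']     # (M,) valid points per voxel
    #   voxel_dict['voxel_centers']  # (M, 3) voxel centers in x, y, z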

    def get_voxel_seg(self, res_coors: torch.Tensor, data_sample: SampleList):
        """Get voxel-wise segmentation label and point2voxel map.

        Args:
            res_coors (Tensor): The voxel coordinates of points, Nx3.
            data_sample (:obj:`Det3DDataSample`): The annotation data of
                every sample. Add voxel-wise annotation for segmentation.
        """

        if self.training:
            pts_semantic_mask = data_sample.gt_pts_seg.pts_semantic_mask
            voxel_semantic_mask, _, point2voxel_map = dynamic_scatter_3d(
                F.one_hot(pts_semantic_mask.long()).float(), res_coors, 'mean',
                True)
            voxel_semantic_mask = torch.argmax(voxel_semantic_mask, dim=-1)
            data_sample.gt_pts_seg.voxel_semantic_mask = voxel_semantic_mask
            data_sample.point2voxel_map = point2voxel_map
        else:
            pseudo_tensor = res_coors.new_ones([res_coors.shape[0], 1]).float()
            _, _, point2voxel_map = dynamic_scatter_3d(pseudo_tensor,
                                                       res_coors, 'mean', True)
            data_sample.point2voxel_map = point2voxel_map
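
    # Sketch of the majority vote above: one-hot point labels are
    # mean-reduced per voxel by ``dynamic_scatter_3d``, so the ``argmax`` of
    # the averaged one-hot vector is the most frequent label in each voxel.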

    def ravel_hash(self, x: np.ndarray) -> np.ndarray:
        """Get voxel coordinates hash for np.unique().

        Args:
            x (np.ndarray): The voxel coordinates of points, Nx3.

        Returns:
            np.ndarray: Voxels coordinates hash.
        """
        assert x.ndim == 2, x.shape

        x = x - np.min(x, axis=0)
        x = x.astype(np.uint64, copy=False)
        xmax = np.max(x, axis=0).astype(np.uint64) + 1

        h = np.zeros(x.shape[0], dtype=np.uint64)
        for k in range(x.shape[1] - 1):
            h += x[:, k]
            h *= xmax[k + 1]
        h += x[:, -1]
        return h
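
    # Worked example (tiny, hypothetical): for x = [[0, 0, 1], [0, 1, 0]],
    # the loop is equivalent to row-major raveling with strides taken from
    # xmax, i.e. np.ravel_multi_index(x.T, xmax), yielding h = [1, 2].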

    def sparse_quantize(self,
                        coords: np.ndarray,
                        return_index: bool = False,
                        return_inverse: bool = False) -> List[np.ndarray]:
        """Sparse Quantization for voxel coordinates used in Minkunet.

        Args:
            coords (np.ndarray): The voxel coordinates of points, Nx3.
            return_index (bool): Whether to return the indices of the
                unique coords, shape (M,).
            return_inverse (bool): Whether to return the indices of the
                original coords shape (N,).

        Returns:
            List[np.ndarray]: Return index and inverse map if
            ``return_index`` and ``return_inverse`` are both True.
        """
        _, indices, inverse_indices = np.unique(
            self.ravel_hash(coords), return_index=True, return_inverse=True)
        coords = coords[indices]

        outputs = []
        if return_index:
            outputs += [indices]
        if return_inverse:
            outputs += [inverse_indices]
        return outputs
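
    # Hedged usage sketch (hypothetical coords): duplicate voxel coordinates
    # collapse to one index, and the inverse map sends each point to its
    # voxel:
    #
    #   coords = np.array([[0, 0, 0], [0, 0, 0], [1, 0, 0]])
    #   inds, point2voxel = self.sparse_quantize(
    #       coords, return_index=True, return_inverse=True)
    #   # inds -> [0, 2]   (first occurrence of each unique row)
    #   # point2voxel -> [0, 0, 1]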