Commit 2072a9df authored by Xiangxu-0103, committed by ZwwWayne

[Refactor] Refactor `voxelization` for faster speed (#2062)

* refactor voxelization for faster speed

* fix doc typo
parent 13ba0dca
 # Copyright (c) OpenMMLab. All rights reserved.
 import math
 from numbers import Number
-from typing import Dict, List, Optional, Sequence, Tuple, Union
+from typing import Dict, List, Optional, Sequence, Union

 import numpy as np
 import torch
@@ -28,24 +28,25 @@ class Det3DDataPreprocessor(DetDataPreprocessor):
     - 1) For image data:
       - Pad images in inputs to the maximum size of current batch with defined
         ``pad_value``. The padding size can be divisible by a defined
-        ``pad_size_divisor``
+        ``pad_size_divisor``.
       - Stack images in inputs to batch_imgs.
       - Convert images in inputs from bgr to rgb if the shape of input is
         (3, H, W).
       - Normalize images in inputs with defined std and mean.
       - Do batch augmentations during training.
     - 2) For point cloud data:
-      - if no voxelization, directly return list of point cloud data.
-      - if voxelization is applied, voxelize point cloud according to
+      - If no voxelization, directly return list of point cloud data.
+      - If voxelization is applied, voxelize point cloud according to
         ``voxel_type`` and obtain ``voxels``.

     Args:
-        voxel (bool): Whether to apply voxelziation to point cloud.
+        voxel (bool): Whether to apply voxelization to point cloud.
+            Defaults to False.
         voxel_type (str): Voxelization type. Two voxelization types are
             provided: 'hard' and 'dynamic', respectively for hard
             voxelization and dynamic voxelization. Defaults to 'hard'.
-        voxel_layer (:obj:`ConfigDict`, optional): Voxelization layer
+        voxel_layer (dict or :obj:`ConfigDict`, optional): Voxelization layer
             config. Defaults to None.
         mean (Sequence[Number], optional): The pixel mean of R, G, B channels.
             Defaults to None.
@@ -54,11 +55,21 @@ class Det3DDataPreprocessor(DetDataPreprocessor):
         pad_size_divisor (int): The size of padded image should be
             divisible by ``pad_size_divisor``. Defaults to 1.
         pad_value (Number): The padded pixel value. Defaults to 0.
-        bgr_to_rgb (bool): whether to convert image from BGR to RGB.
+        pad_mask (bool): Whether to pad instance masks. Defaults to False.
+        mask_pad_value (int): The padded pixel value for instance masks.
+            Defaults to 0.
+        pad_seg (bool): Whether to pad semantic segmentation maps.
+            Defaults to False.
+        seg_pad_value (int): The padded pixel value for semantic
+            segmentation maps. Defaults to 255.
+        bgr_to_rgb (bool): Whether to convert image from BGR to RGB.
             Defaults to False.
-        rgb_to_bgr (bool): whether to convert image from RGB to RGB.
+        rgb_to_bgr (bool): Whether to convert image from RGB to BGR.
             Defaults to False.
-        batch_augments (list[dict], optional): Batch-level augmentations
+        boxtype2tensor (bool): Whether to keep the ``BaseBoxes`` type of
+            bboxes data or not. Defaults to True.
+        batch_augments (List[dict], optional): Batch-level augmentations.
             Defaults to None.
     """

     def __init__(self,
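For orientation, the snippet below sketches how this preprocessor is typically wired into a detector config with hard voxelization enabled. The concrete voxel_layer values are illustrative assumptions, not taken from this commit; the keys mirror the parameters of the Voxelization op constructed in __init__.

# A minimal sketch, assuming a KITTI-style LiDAR detector; values are
# illustrative, not part of this commit.
data_preprocessor = dict(
    type='Det3DDataPreprocessor',
    voxel=True,
    voxel_type='hard',
    voxel_layer=dict(
        max_num_points=32,                            # points kept per voxel
        point_cloud_range=[0, -40, -3, 70.4, 40, 1],  # (x0, y0, z0, x1, y1, z1)
        voxel_size=[0.05, 0.05, 0.1],
        max_voxels=(16000, 40000)),                   # (training, testing)
    mean=[123.675, 116.28, 103.53],                   # only used for images
    std=[58.395, 57.12, 57.375],
    bgr_to_rgb=True)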
@@ -76,8 +87,8 @@ class Det3DDataPreprocessor(DetDataPreprocessor):
                  bgr_to_rgb: bool = False,
                  rgb_to_bgr: bool = False,
                  boxtype2tensor: bool = True,
-                 batch_augments: Optional[List[dict]] = None):
-        super().__init__(
+                 batch_augments: Optional[List[dict]] = None) -> None:
+        super(Det3DDataPreprocessor, self).__init__(
             mean=mean,
             std=std,
             pad_size_divisor=pad_size_divisor,
@@ -94,24 +105,21 @@ class Det3DDataPreprocessor(DetDataPreprocessor):
         if voxel:
             self.voxel_layer = Voxelization(**voxel_layer)

-    def forward(
-        self,
-        data: Union[dict, List[dict]],
-        training: bool = False
-    ) -> Tuple[Union[dict, List[dict]], Optional[list]]:
-        """Perform normalization、padding and bgr2rgb conversion based on
+    def forward(self,
+                data: Union[dict, List[dict]],
+                training: bool = False) -> Union[dict, List[dict]]:
+        """Perform normalization, padding and bgr2rgb conversion based on
         ``BaseDataPreprocessor``.

         Args:
-            data (dict | List[dict]): data from dataloader.
+            data (dict or List[dict]): Data from dataloader.
                 The dict contains the whole batch data, when it is
                 a list[dict], the list indicate test time augmentation.
             training (bool): Whether to enable training time augmentation.
                 Defaults to False.

         Returns:
-            Dict | List[Dict]: Data in the same format as the model input.
+            dict or List[dict]: Data in the same format as the model input.
         """
         if isinstance(data, list):
             num_augs = len(data)
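The hunk above is cut off right after the list check. A minimal sketch of the intended dispatch, under the assumption that each test-time-augmented view of the batch is processed independently by simple_process (consistent with the return statement visible in the next hunk):

# Sketch of forward()'s dispatch; the loop body is an assumption based on
# the truncated hunk above, not the verbatim commit.
if isinstance(data, list):
    num_augs = len(data)
    aug_batch_data = []
    for aug_id in range(num_augs):
        # each entry is one test-time-augmented view of the whole batch
        aug_batch_data.append(self.simple_process(data[aug_id], training))
    return aug_batch_data
else:
    return self.simple_process(data, training)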
@@ -126,7 +134,7 @@ class Det3DDataPreprocessor(DetDataPreprocessor):
             return self.simple_process(data, training)

     def simple_process(self, data: dict, training: bool = False) -> dict:
-        """Perform normalization、padding and bgr2rgb conversion for img data
+        """Perform normalization, padding and bgr2rgb conversion for img data
         based on ``BaseDataPreprocessor``, and voxelize point cloud if `voxel`
         is set to be True.
@@ -188,7 +196,7 @@ class Det3DDataPreprocessor(DetDataPreprocessor):
         return {'inputs': batch_inputs, 'data_samples': data_samples}

-    def preprocess_img(self, _batch_img):
+    def preprocess_img(self, _batch_img: torch.Tensor) -> torch.Tensor:
         # channel transform
         if self._channel_conversion:
             _batch_img = _batch_img[[2, 1, 0], ...]
@@ -206,7 +214,7 @@ class Det3DDataPreprocessor(DetDataPreprocessor):
         return _batch_img

     def collate_data(self, data: dict) -> dict:
-        """Copying data to the target device and Performs normalization
+        """Copying data to the target device and Performs normalization,
         padding and bgr2rgb conversion and stack based on
         ``BaseDataPreprocessor``.
...@@ -273,7 +281,7 @@ class Det3DDataPreprocessor(DetDataPreprocessor): ...@@ -273,7 +281,7 @@ class Det3DDataPreprocessor(DetDataPreprocessor):
raise TypeError( raise TypeError(
'Output of `cast_data` should be a list of dict ' 'Output of `cast_data` should be a list of dict '
'or a tuple with inputs and data_samples, but got' 'or a tuple with inputs and data_samples, but got'
f'{type(data)} {data}') f'{type(data)}: {data}')
data['inputs']['imgs'] = batch_imgs data['inputs']['imgs'] = batch_imgs
@@ -284,14 +292,14 @@ class Det3DDataPreprocessor(DetDataPreprocessor):
     def _get_pad_shape(self, data: dict) -> List[tuple]:
         """Get the pad_shape of each image based on data and
        pad_size_divisor."""
-        # rewrite `_get_pad_shape` for obaining image inputs.
+        # rewrite `_get_pad_shape` for obtaining image inputs.
         _batch_inputs = data['inputs']['img']
         # Process data with `pseudo_collate`.
         if is_list_of(_batch_inputs, torch.Tensor):
             batch_pad_shape = []
             for ori_input in _batch_inputs:
                 if ori_input.dim() == 4:
-                    # mean multiivew input, select ont of the
+                    # mean multiview input, select one of the
                     # image to calculate the pad shape
                     ori_input = ori_input[0]
                 pad_h = int(
@@ -316,24 +324,24 @@ class Det3DDataPreprocessor(DetDataPreprocessor):
             batch_pad_shape = [(pad_h, pad_w)] * _batch_inputs.shape[0]
         else:
             raise TypeError('Output of `cast_data` should be a list of dict '
-                            'or a tuple with inputs and data_samples, but got'
+                            'or a tuple with inputs and data_samples, but got '
                            f'{type(data)}: {data}')
         return batch_pad_shape
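For reference, the pad shape computed above rounds each spatial size up to the nearest multiple of pad_size_divisor. A standalone illustration of that arithmetic (pad_shape_for is a hypothetical helper, not part of the module):

import math

def pad_shape_for(h: int, w: int, pad_size_divisor: int) -> tuple:
    # round H and W up to the nearest multiple of pad_size_divisor,
    # mirroring the ceil-based computation in _get_pad_shape
    pad_h = int(math.ceil(h / pad_size_divisor)) * pad_size_divisor
    pad_w = int(math.ceil(w / pad_size_divisor)) * pad_size_divisor
    return (pad_h, pad_w)

# e.g. with pad_size_divisor=32, a 375x1242 KITTI image pads to (384, 1248)
assert pad_shape_for(375, 1242, 32) == (384, 1248)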
     @torch.no_grad()
-    def voxelize(self, points: List[torch.Tensor]) -> Dict:
+    def voxelize(self, points: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
         """Apply voxelization to point cloud.

         Args:
             points (List[Tensor]): Point cloud in one data batch.

         Returns:
-            dict[str, Tensor]: Voxelization information.
+            Dict[str, Tensor]: Voxelization information.

-            - voxels (Tensor): Features of voxels, shape is MXNxC for hard
-              voxelization, NXC for dynamic voxelization.
+            - voxels (Tensor): Features of voxels, shape is MxNxC for hard
+              voxelization, NxC for dynamic voxelization.
             - coors (Tensor): Coordinates of voxels, shape is Nx(1+NDim),
               where 1 represents the batch index.
             - num_points (Tensor, optional): Number of points in each voxel.
             - voxel_centers (Tensor, optional): Centers of voxels.
         """
@@ -342,43 +350,38 @@ class Det3DDataPreprocessor(DetDataPreprocessor):
         if self.voxel_type == 'hard':
             voxels, coors, num_points, voxel_centers = [], [], [], []
-            for res in points:
+            for i, res in enumerate(points):
                 res_voxels, res_coors, res_num_points = self.voxel_layer(res)
                 res_voxel_centers = (
                     res_coors[:, [2, 1, 0]] + 0.5) * res_voxels.new_tensor(
                         self.voxel_layer.voxel_size) + res_voxels.new_tensor(
                             self.voxel_layer.point_cloud_range[0:3])
+                res_coors = F.pad(res_coors, (1, 0), mode='constant', value=i)
                 voxels.append(res_voxels)
                 coors.append(res_coors)
                 num_points.append(res_num_points)
                 voxel_centers.append(res_voxel_centers)

             voxels = torch.cat(voxels, dim=0)
+            coors = torch.cat(coors, dim=0)
             num_points = torch.cat(num_points, dim=0)
             voxel_centers = torch.cat(voxel_centers, dim=0)
-            coors_batch = []
-            for i, coor in enumerate(coors):
-                coor_pad = F.pad(coor, (1, 0), mode='constant', value=i)
-                coors_batch.append(coor_pad)
-            coors_batch = torch.cat(coors_batch, dim=0)
             voxel_dict['num_points'] = num_points
             voxel_dict['voxel_centers'] = voxel_centers
         elif self.voxel_type == 'dynamic':
             coors = []
             # dynamic voxelization only provide a coors mapping
-            for res in points:
+            for i, res in enumerate(points):
                 res_coors = self.voxel_layer(res)
+                res_coors = F.pad(res_coors, (1, 0), mode='constant', value=i)
                 coors.append(res_coors)
             voxels = torch.cat(points, dim=0)
-            coors_batch = []
-            for i, coor in enumerate(coors):
-                coor_pad = F.pad(coor, (1, 0), mode='constant', value=i)
-                coors_batch.append(coor_pad)
-            coors_batch = torch.cat(coors_batch, dim=0)
+            coors = torch.cat(coors, dim=0)
         else:
             raise ValueError(f'Invalid voxelization type {self.voxel_type}')

         voxel_dict['voxels'] = voxels
-        voxel_dict['coors'] = coors_batch
+        voxel_dict['coors'] = coors

         return voxel_dict
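The core of the refactor is visible in this hunk: instead of concatenating per-sample coordinates and then iterating over them a second time to prepend the batch index, each sample's coordinates get the index attached with a single F.pad inside the loop that already exists, followed by one torch.cat. A self-contained sketch of the pattern with toy data:

import torch
import torch.nn.functional as F

# Toy voxel coordinates (z, y, x) for a batch of two point clouds.
coors_sample0 = torch.tensor([[0, 1, 2], [3, 4, 5]])
coors_sample1 = torch.tensor([[6, 7, 8]])

coors = []
for i, res_coors in enumerate([coors_sample0, coors_sample1]):
    # F.pad with (1, 0) prepends one column to the last dim; value=i writes
    # the batch index, turning (z, y, x) into (batch_idx, z, y, x) in one shot.
    coors.append(F.pad(res_coors, (1, 0), mode='constant', value=i))
coors = torch.cat(coors, dim=0)

print(coors)
# tensor([[0, 0, 1, 2],
#         [0, 3, 4, 5],
#         [1, 6, 7, 8]])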
@@ -12,7 +12,7 @@ def multiview_img_stack_batch(
     """
     Compared to the stack_batch in mmengine.model.utils,
     multiview_img_stack_batch further handle the multiview images.
-    see diff of padded_sizes[:, :-2] = 0 vs padded_sizees[:, 0] = 0 in line 47
+    see diff of padded_sizes[:, :-2] = 0 vs padded_sizes[:, 0] = 0 in line 47

     Stack multiple tensors to form a batch and pad the tensor to the max
     shape use the right bottom padding mode in these images. If
     ``pad_size_divisor > 0``, add padding to ensure the shape of each dim is
@@ -23,20 +23,20 @@ def multiview_img_stack_batch(
         pad_size_divisor (int): If ``pad_size_divisor > 0``, add padding
             to ensure the shape of each dim is divisible by
             ``pad_size_divisor``. This depends on the model, and many
-            models need to be divisible by 32. Defaults to 1
-        pad_value (int, float): The padding value. Defaults to 0.
+            models need to be divisible by 32. Defaults to 1.
+        pad_value (int or float): The padding value. Defaults to 0.

     Returns:
         Tensor: The n dim tensor.
     """
     assert isinstance(
         tensor_list,
-        list), (f'Expected input type to be list, but got {type(tensor_list)}')
+        list), f'Expected input type to be list, but got {type(tensor_list)}'
     assert tensor_list, '`tensor_list` could not be an empty list'
     assert len({
         tensor.ndim
         for tensor in tensor_list
-    }) == 1, (f'Expected the dimensions of all tensors must be the same, '
+    }) == 1, ('Expected the dimensions of all tensors must be the same, '
              f'but got {[tensor.ndim for tensor in tensor_list]}')

     dim = tensor_list[0].dim()
@@ -46,7 +46,7 @@ def multiview_img_stack_batch(
     max_sizes = torch.ceil(
         torch.max(all_sizes, dim=0)[0] / pad_size_divisor) * pad_size_divisor
     padded_sizes = max_sizes - all_sizes
     # The first dim normally means channel, which should not be padded.
     padded_sizes[:, :-2] = 0
     if padded_sizes.sum() == 0:
         return torch.stack(tensor_list)
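Zeroing padded_sizes[:, :-2] keeps every leading dimension unpadded (views and channels for 4-D multiview tensors), so only H and W grow; the single-view stack_batch in mmengine only zeroes dim 0. A small sketch of the resulting behavior, assuming right-bottom padding as in this function:

import torch
import torch.nn.functional as F

# Two multiview tensors: (num_views, C, H, W) with different H and W.
a = torch.zeros(6, 3, 370, 1220)
b = torch.zeros(6, 3, 376, 1240)

all_sizes = torch.tensor([list(a.shape), list(b.shape)])
max_sizes = torch.ceil(torch.max(all_sizes, dim=0)[0] / 32) * 32
padded_sizes = max_sizes - all_sizes
padded_sizes[:, :-2] = 0  # never pad the view or channel dims

# right-bottom padding: (left, right, top, bottom) = (0, pad_w, 0, pad_h)
batch = torch.stack([
    F.pad(t, (0, int(p[-1]), 0, int(p[-2])))
    for t, p in zip([a, b], padded_sizes)
])
print(batch.shape)  # torch.Size([2, 6, 3, 384, 1248])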