Unverified Commit 5412046b, authored by Xiangxu-0103 and committed by GitHub
Browse files

[Enhance]: Add typehints for dataset transforms and fix potential bug for `PointSample` (#1875)

* update dataset transforms

* update dbsampler docstring and add typehints

* add type hints and fix potential point sample bug

* fix lint

* fix

* fix
parent d8c9bc66
...@@ -32,7 +32,7 @@ class Compose: ...@@ -32,7 +32,7 @@ class Compose:
data (dict): A result dict contains the data to transform. data (dict): A result dict contains the data to transform.
Returns: Returns:
dict: Transformed data. dict: Transformed data.
""" """
for t in self.transforms: for t in self.transforms:
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import copy import copy
import os import os
import warnings from typing import List, Optional
import mmengine import mmengine
import numpy as np import numpy as np
...@@ -16,18 +16,19 @@ class BatchSampler: ...@@ -16,18 +16,19 @@ class BatchSampler:
Args: Args:
sample_list (list[dict]): List of samples. sample_list (list[dict]): List of samples.
name (str, optional): The category of samples. Default: None. name (str, optional): The category of samples. Defaults to None.
epoch (int, optional): Sampling epoch. Default: None. epoch (int, optional): Sampling epoch. Defaults to None.
shuffle (bool, optional): Whether to shuffle indices. Default: False. shuffle (bool, optional): Whether to shuffle indices.
drop_reminder (bool, optional): Drop reminder. Default: False. Defaults to False.
drop_reminder (bool, optional): Drop reminder. Defaults to False.
""" """
def __init__(self, def __init__(self,
sampled_list, sampled_list: List[dict],
name=None, name: Optional[str] = None,
epoch=None, epoch: Optional[int] = None,
shuffle=True, shuffle: bool = True,
drop_reminder=False): drop_reminder: bool = False) -> None:
self._sampled_list = sampled_list self._sampled_list = sampled_list
self._indices = np.arange(len(sampled_list)) self._indices = np.arange(len(sampled_list))
if shuffle: if shuffle:
...@@ -40,7 +41,7 @@ class BatchSampler: ...@@ -40,7 +41,7 @@ class BatchSampler:
self._epoch_counter = 0 self._epoch_counter = 0
self._drop_reminder = drop_reminder self._drop_reminder = drop_reminder
def _sample(self, num): def _sample(self, num: int) -> List[int]:
"""Sample specific number of ground truths and return indices. """Sample specific number of ground truths and return indices.
Args: Args:
...@@ -57,7 +58,7 @@ class BatchSampler: ...@@ -57,7 +58,7 @@ class BatchSampler:
self._idx += num self._idx += num
return ret return ret
def _reset(self): def _reset(self) -> None:
"""Reset the index of batchsampler to zero.""" """Reset the index of batchsampler to zero."""
assert self._name is not None assert self._name is not None
# print("reset", self._name) # print("reset", self._name)
...@@ -65,7 +66,7 @@ class BatchSampler: ...@@ -65,7 +66,7 @@ class BatchSampler:
np.random.shuffle(self._indices) np.random.shuffle(self._indices)
self._idx = 0 self._idx = 0
def sample(self, num): def sample(self, num: int) -> List[dict]:
"""Sample specific number of ground truths. """Sample specific number of ground truths.
Args: Args:
...@@ -88,24 +89,28 @@ class DataBaseSampler(object): ...@@ -88,24 +89,28 @@ class DataBaseSampler(object):
rate (float): Rate of actual sampled over maximum sampled number. rate (float): Rate of actual sampled over maximum sampled number.
prepare (dict): Name of preparation functions and the input value. prepare (dict): Name of preparation functions and the input value.
sample_groups (dict): Sampled classes and numbers. sample_groups (dict): Sampled classes and numbers.
classes (list[str], optional): List of classes. Default: None. classes (list[str], optional): List of classes. Defaults to None.
points_loader(dict, optional): Config of points loader. Default: points_loader(dict, optional): Config of points loader. Defaults to
dict(type='LoadPointsFromFile', load_dim=4, use_dim=[0,1,2,3]) dict(type='LoadPointsFromFile', load_dim=4, use_dim=[0, 1, 2, 3]).
file_client_args (dict, optional): Config dict of file clients,
refer to
https://github.com/open-mmlab/mmengine/blob/main/mmengine/fileio/file_client.py
for more details. Defaults to dict(backend='disk').
""" """
def __init__(self, def __init__(self,
info_path, info_path: str,
data_root, data_root: str,
rate, rate: float,
prepare, prepare: dict,
sample_groups, sample_groups: dict,
classes=None, classes: Optional[List[str]] = None,
points_loader=dict( points_loader: dict = dict(
type='LoadPointsFromFile', type='LoadPointsFromFile',
coord_type='LIDAR', coord_type='LIDAR',
load_dim=4, load_dim=4,
use_dim=[0, 1, 2, 3]), use_dim=[0, 1, 2, 3]),
file_client_args=dict(backend='disk')): file_client_args: dict = dict(backend='disk')) -> None:
super().__init__() super().__init__()
self.data_root = data_root self.data_root = data_root
self.info_path = info_path self.info_path = info_path
...@@ -118,18 +123,9 @@ class DataBaseSampler(object): ...@@ -118,18 +123,9 @@ class DataBaseSampler(object):
self.file_client = mmengine.FileClient(**file_client_args) self.file_client = mmengine.FileClient(**file_client_args)
# load data base infos # load data base infos
if hasattr(self.file_client, 'get_local_path'): with self.file_client.get_local_path(info_path) as local_path:
with self.file_client.get_local_path(info_path) as local_path: # loading data from a file-like object needs file format
# loading data from a file-like object needs file format db_infos = mmengine.load(open(local_path, 'rb'), file_format='pkl')
db_infos = mmengine.load(
open(local_path, 'rb'), file_format='pkl')
else:
warnings.warn(
'The used MMCV version does not have get_local_path. '
f'We treat the {info_path} as local paths and it '
'might cause errors if the path is not a local path. '
'Please use MMCV>= 1.3.16 if you meet errors.')
db_infos = mmengine.load(info_path)
# filter database infos # filter database infos
from mmengine.logging import MMLogger from mmengine.logging import MMLogger
...@@ -163,7 +159,7 @@ class DataBaseSampler(object): ...@@ -163,7 +159,7 @@ class DataBaseSampler(object):
# TODO: No group_sampling currently # TODO: No group_sampling currently
@staticmethod @staticmethod
def filter_by_difficulty(db_infos, removed_difficulty): def filter_by_difficulty(db_infos: dict, removed_difficulty: list) -> dict:
"""Filter ground truths by difficulties. """Filter ground truths by difficulties.
Args: Args:
...@@ -182,7 +178,7 @@ class DataBaseSampler(object): ...@@ -182,7 +178,7 @@ class DataBaseSampler(object):
return new_db_infos return new_db_infos
@staticmethod @staticmethod
def filter_by_min_points(db_infos, min_gt_points_dict): def filter_by_min_points(db_infos: dict, min_gt_points_dict: dict) -> dict:
"""Filter ground truths by number of points in the bbox. """Filter ground truths by number of points in the bbox.
Args: Args:
...@@ -203,12 +199,19 @@ class DataBaseSampler(object): ...@@ -203,12 +199,19 @@ class DataBaseSampler(object):
db_infos[name] = filtered_infos db_infos[name] = filtered_infos
return db_infos return db_infos
def sample_all(self, gt_bboxes, gt_labels, img=None, ground_plane=None): def sample_all(self,
gt_bboxes: np.ndarray,
gt_labels: np.ndarray,
img: Optional[np.ndarray] = None,
ground_plane: Optional[np.ndarray] = None) -> dict:
"""Sampling all categories of bboxes. """Sampling all categories of bboxes.
Args: Args:
gt_bboxes (np.ndarray): Ground truth bounding boxes. gt_bboxes (np.ndarray): Ground truth bounding boxes.
gt_labels (np.ndarray): Ground truth labels of boxes. gt_labels (np.ndarray): Ground truth labels of boxes.
img (np.ndarray, optional): Image array. Defaults to None.
ground_plane (np.ndarray, optional): Ground plane information.
Defaults to None.
Returns: Returns:
dict: Dict of sampled 'pseudo ground truths'. dict: Dict of sampled 'pseudo ground truths'.
...@@ -301,7 +304,10 @@ class DataBaseSampler(object): ...@@ -301,7 +304,10 @@ class DataBaseSampler(object):
return ret return ret
def sample_class_v2(self, name, num, gt_bboxes): def sample_class_v2(self,
name: str,
num: int,
gt_bboxes: np.ndarray) -> List[dict]:
"""Sampling specific categories of bounding boxes. """Sampling specific categories of bounding boxes.
Args: Args:
......
...@@ -63,16 +63,16 @@ class Pack3DDetInputs(BaseTransform): ...@@ -63,16 +63,16 @@ class Pack3DDetInputs(BaseTransform):
def __init__( def __init__(
self, self,
keys: dict, keys: tuple,
meta_keys: dict = ('img_path', 'ori_shape', 'img_shape', 'lidar2img', meta_keys: tuple = ('img_path', 'ori_shape', 'img_shape', 'lidar2img',
'depth2img', 'cam2img', 'pad_shape', 'scale_factor', 'depth2img', 'cam2img', 'pad_shape',
'flip', 'pcd_horizontal_flip', 'pcd_vertical_flip', 'scale_factor', 'flip', 'pcd_horizontal_flip',
'box_mode_3d', 'box_type_3d', 'img_norm_cfg', 'pcd_vertical_flip', 'box_mode_3d', 'box_type_3d',
'num_pts_feats', 'pcd_trans', 'sample_idx', 'img_norm_cfg', 'num_pts_feats', 'pcd_trans',
'pcd_scale_factor', 'pcd_rotation', 'sample_idx', 'pcd_scale_factor', 'pcd_rotation',
'pcd_rotation_angle', 'lidar_path', 'pcd_rotation_angle', 'lidar_path',
'transformation_3d_flow', 'trans_mat', 'transformation_3d_flow', 'trans_mat',
'affine_aug')): 'affine_aug')) -> None:
self.keys = keys self.keys = keys
self.meta_keys = meta_keys self.meta_keys = meta_keys
...@@ -99,7 +99,7 @@ class Pack3DDetInputs(BaseTransform): ...@@ -99,7 +99,7 @@ class Pack3DDetInputs(BaseTransform):
- img - img
- 'data_samples' (obj:`Det3DDataSample`): The annotation info of - 'data_samples' (obj:`Det3DDataSample`): The annotation info of
the sample. the sample.
""" """
# augtest # augtest
if isinstance(results, list): if isinstance(results, list):
...@@ -116,7 +116,7 @@ class Pack3DDetInputs(BaseTransform): ...@@ -116,7 +116,7 @@ class Pack3DDetInputs(BaseTransform):
else: else:
raise NotImplementedError raise NotImplementedError
def pack_single_results(self, results): def pack_single_results(self, results: dict) -> dict:
"""Method to pack the single input data. when the value in this dict is """Method to pack the single input data. when the value in this dict is
a list, it usually is in Augmentations Testing. a list, it usually is in Augmentations Testing.
...@@ -132,7 +132,7 @@ class Pack3DDetInputs(BaseTransform): ...@@ -132,7 +132,7 @@ class Pack3DDetInputs(BaseTransform):
- points - points
- img - img
- 'data_samples' (obj:`Det3DDataSample`): The annotation info - 'data_samples' (:obj:`Det3DDataSample`): The annotation info
of the sample. of the sample.
""" """
# Format 3D data # Format 3D data
...@@ -220,6 +220,7 @@ class Pack3DDetInputs(BaseTransform): ...@@ -220,6 +220,7 @@ class Pack3DDetInputs(BaseTransform):
return packed_results return packed_results
def __repr__(self) -> str: def __repr__(self) -> str:
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__ repr_str = self.__class__.__name__
repr_str += f'(keys={self.keys})' repr_str += f'(keys={self.keys})'
repr_str += f'(meta_keys={self.meta_keys})' repr_str += f'(meta_keys={self.meta_keys})'
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import List from typing import List, Union
import mmcv import mmcv
import mmengine import mmengine
...@@ -13,7 +13,7 @@ from mmdet.datasets.transforms import LoadAnnotations ...@@ -13,7 +13,7 @@ from mmdet.datasets.transforms import LoadAnnotations
@TRANSFORMS.register_module() @TRANSFORMS.register_module()
class LoadMultiViewImageFromFiles(object): class LoadMultiViewImageFromFiles(BaseTransform):
"""Load multi channel images from a list of separate channel files. """Load multi channel images from a list of separate channel files.
Expects results['img_filename'] to be a list of filenames. Expects results['img_filename'] to be a list of filenames.
...@@ -25,11 +25,15 @@ class LoadMultiViewImageFromFiles(object): ...@@ -25,11 +25,15 @@ class LoadMultiViewImageFromFiles(object):
Defaults to 'unchanged'. Defaults to 'unchanged'.
""" """
def __init__(self, to_float32=False, color_type='unchanged'): def __init__(
self,
to_float32: bool = False,
color_type: str = 'unchanged'
) -> None:
self.to_float32 = to_float32 self.to_float32 = to_float32
self.color_type = color_type self.color_type = color_type
def __call__(self, results): def transform(self, results: dict) -> dict:
"""Call function to load multi-view image from files. """Call function to load multi-view image from files.
Args: Args:
...@@ -139,7 +143,7 @@ class LoadPointsFromMultiSweeps(BaseTransform): ...@@ -139,7 +143,7 @@ class LoadPointsFromMultiSweeps(BaseTransform):
Defaults to [0, 1, 2, 4]. Defaults to [0, 1, 2, 4].
file_client_args (dict, optional): Config dict of file clients, file_client_args (dict, optional): Config dict of file clients,
refer to refer to
https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py https://github.com/open-mmlab/mmengine/blob/main/mmengine/fileio/file_client.py
for more details. Defaults to dict(backend='disk'). for more details. Defaults to dict(backend='disk').
pad_empty_sweeps (bool, optional): Whether to repeat keyframe when pad_empty_sweeps (bool, optional): Whether to repeat keyframe when
sweeps is empty. Defaults to False. sweeps is empty. Defaults to False.
...@@ -150,14 +154,16 @@ class LoadPointsFromMultiSweeps(BaseTransform): ...@@ -150,14 +154,16 @@ class LoadPointsFromMultiSweeps(BaseTransform):
Defaults to False. Defaults to False.
""" """
def __init__(self, def __init__(
sweeps_num=10, self,
load_dim=5, sweeps_num: int = 10,
use_dim=[0, 1, 2, 4], load_dim: int = 5,
file_client_args=dict(backend='disk'), use_dim: List[int] = [0, 1, 2, 4],
pad_empty_sweeps=False, file_client_args: dict = dict(backend='disk'),
remove_close=False, pad_empty_sweeps: bool = False,
test_mode=False): remove_close: bool = False,
test_mode: bool = False
) -> None:
self.load_dim = load_dim self.load_dim = load_dim
self.sweeps_num = sweeps_num self.sweeps_num = sweeps_num
self.use_dim = use_dim self.use_dim = use_dim
...@@ -167,7 +173,7 @@ class LoadPointsFromMultiSweeps(BaseTransform): ...@@ -167,7 +173,7 @@ class LoadPointsFromMultiSweeps(BaseTransform):
self.remove_close = remove_close self.remove_close = remove_close
self.test_mode = test_mode self.test_mode = test_mode
def _load_points(self, pts_filename): def _load_points(self, pts_filename: str) -> np.ndarray:
"""Private function to load point clouds data. """Private function to load point clouds data.
Args: Args:
...@@ -189,7 +195,11 @@ class LoadPointsFromMultiSweeps(BaseTransform): ...@@ -189,7 +195,11 @@ class LoadPointsFromMultiSweeps(BaseTransform):
points = np.fromfile(pts_filename, dtype=np.float32) points = np.fromfile(pts_filename, dtype=np.float32)
return points return points
def _remove_close(self, points, radius=1.0): def _remove_close(
self,
points: Union[np.ndarray, BasePoints],
radius: float = 1.0
) -> Union[np.ndarray, BasePoints]:
"""Removes point too close within a certain radius from origin. """Removes point too close within a certain radius from origin.
Args: Args:
...@@ -198,7 +208,7 @@ class LoadPointsFromMultiSweeps(BaseTransform): ...@@ -198,7 +208,7 @@ class LoadPointsFromMultiSweeps(BaseTransform):
Defaults to 1.0. Defaults to 1.0.
Returns: Returns:
np.ndarray: Points after removing. np.ndarray | :obj:`BasePoints`: Points after removing.
""" """
if isinstance(points, np.ndarray): if isinstance(points, np.ndarray):
points_numpy = points points_numpy = points
...@@ -211,7 +221,7 @@ class LoadPointsFromMultiSweeps(BaseTransform): ...@@ -211,7 +221,7 @@ class LoadPointsFromMultiSweeps(BaseTransform):
not_close = np.logical_not(np.logical_and(x_filt, y_filt)) not_close = np.logical_not(np.logical_and(x_filt, y_filt))
return points[not_close] return points[not_close]
def transform(self, results): def transform(self, results: dict) -> dict:
"""Call function to load multi-sweep point clouds from files. """Call function to load multi-sweep point clouds from files.
Args: Args:
...@@ -220,7 +230,7 @@ class LoadPointsFromMultiSweeps(BaseTransform): ...@@ -220,7 +230,7 @@ class LoadPointsFromMultiSweeps(BaseTransform):
Returns: Returns:
dict: The result dict containing the multi-sweep points data. dict: The result dict containing the multi-sweep points data.
Added key and value are described below. Updated key and value are described below.
- points (np.ndarray | :obj:`BasePoints`): Multi-sweep point - points (np.ndarray | :obj:`BasePoints`): Multi-sweep point
cloud arrays. cloud arrays.
...@@ -290,7 +300,7 @@ class PointSegClassMapping(BaseTransform): ...@@ -290,7 +300,7 @@ class PointSegClassMapping(BaseTransform):
others as len(valid_cat_ids). others as len(valid_cat_ids).
""" """
def transform(self, results: dict) -> None: def transform(self, results: dict) -> dict:
"""Call function to map original semantic class to valid category ids. """Call function to map original semantic class to valid category ids.
Args: Args:
...@@ -322,8 +332,6 @@ class PointSegClassMapping(BaseTransform): ...@@ -322,8 +332,6 @@ class PointSegClassMapping(BaseTransform):
def __repr__(self): def __repr__(self):
"""str: Return a string that describes the module.""" """str: Return a string that describes the module."""
repr_str = self.__class__.__name__ repr_str = self.__class__.__name__
repr_str += f'(valid_cat_ids={self.valid_cat_ids}, '
repr_str += f'max_cat_id={self.max_cat_id})'
return repr_str return repr_str
...@@ -385,13 +393,14 @@ class LoadPointsFromFile(BaseTransform): ...@@ -385,13 +393,14 @@ class LoadPointsFromFile(BaseTransform):
Args: Args:
coord_type (str): The type of coordinates of points cloud. coord_type (str): The type of coordinates of points cloud.
Available options includes: Available options includes:
- 'LIDAR': Points in LiDAR coordinates. - 'LIDAR': Points in LiDAR coordinates.
- 'DEPTH': Points in depth coordinates, usually for indoor dataset. - 'DEPTH': Points in depth coordinates, usually for indoor dataset.
- 'CAMERA': Points in camera coordinates. - 'CAMERA': Points in camera coordinates.
load_dim (int, optional): The dimension of the loaded points. load_dim (int, optional): The dimension of the loaded points.
Defaults to 6. Defaults to 6.
use_dim (list[int], optional): Which dimensions of the points to use. use_dim (list[int] | int, optional): Which dimensions of the points
Defaults to [0, 1, 2]. For KITTI dataset, set use_dim=4 to use. Defaults to [0, 1, 2]. For KITTI dataset, set use_dim=4
or use_dim=[0, 1, 2, 3] to use the intensity dimension. or use_dim=[0, 1, 2, 3] to use the intensity dimension.
shift_height (bool, optional): Whether to use shifted height. shift_height (bool, optional): Whether to use shifted height.
Defaults to False. Defaults to False.
...@@ -399,7 +408,7 @@ class LoadPointsFromFile(BaseTransform): ...@@ -399,7 +408,7 @@ class LoadPointsFromFile(BaseTransform):
Defaults to False. Defaults to False.
file_client_args (dict, optional): Config dict of file clients, file_client_args (dict, optional): Config dict of file clients,
refer to refer to
https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py https://github.com/open-mmlab/mmengine/blob/main/mmengine/fileio/file_client.py
for more details. Defaults to dict(backend='disk'). for more details. Defaults to dict(backend='disk').
""" """
...@@ -407,7 +416,7 @@ class LoadPointsFromFile(BaseTransform): ...@@ -407,7 +416,7 @@ class LoadPointsFromFile(BaseTransform):
self, self,
coord_type: str, coord_type: str,
load_dim: int = 6, load_dim: int = 6,
use_dim: list = [0, 1, 2], use_dim: Union[int, List[int]] = [0, 1, 2],
shift_height: bool = False, shift_height: bool = False,
use_color: bool = False, use_color: bool = False,
file_client_args: dict = dict(backend='disk') file_client_args: dict = dict(backend='disk')
...@@ -523,6 +532,7 @@ class LoadAnnotations3D(LoadAnnotations): ...@@ -523,6 +532,7 @@ class LoadAnnotations3D(LoadAnnotations):
Required Keys: Required Keys:
- ann_info (dict) - ann_info (dict)
- gt_bboxes_3d (:obj:`LiDARInstance3DBoxes` | - gt_bboxes_3d (:obj:`LiDARInstance3DBoxes` |
:obj:`DepthInstance3DBoxes` | :obj:`CameraInstance3DBoxes`): :obj:`DepthInstance3DBoxes` | :obj:`CameraInstance3DBoxes`):
3D ground truth bboxes. Only when `with_bbox_3d` is True 3D ground truth bboxes. Only when `with_bbox_3d` is True
...@@ -592,7 +602,7 @@ class LoadAnnotations3D(LoadAnnotations): ...@@ -592,7 +602,7 @@ class LoadAnnotations3D(LoadAnnotations):
seg_3d_dtype (dtype, optional): Dtype of 3D semantic masks. seg_3d_dtype (dtype, optional): Dtype of 3D semantic masks.
Defaults to int64. Defaults to int64.
file_client_args (dict): Config dict of file clients, refer to file_client_args (dict): Config dict of file clients, refer to
https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py https://github.com/open-mmlab/mmengine/blob/main/mmengine/fileio/file_client.py
for more details. for more details.
""" """
......
...@@ -16,7 +16,7 @@ class MultiScaleFlipAug3D(BaseTransform): ...@@ -16,7 +16,7 @@ class MultiScaleFlipAug3D(BaseTransform):
Args: Args:
transforms (list[dict]): Transforms to apply in each augmentation. transforms (list[dict]): Transforms to apply in each augmentation.
img_scale (tuple | list[tuple]: Images scales for resizing. img_scale (tuple | list[tuple]): Images scales for resizing.
pts_scale_ratio (float | list[float]): Points scale ratios for pts_scale_ratio (float | list[float]): Points scale ratios for
resizing. resizing.
flip (bool, optional): Whether apply flip augmentation. flip (bool, optional): Whether apply flip augmentation.
...@@ -25,11 +25,11 @@ class MultiScaleFlipAug3D(BaseTransform): ...@@ -25,11 +25,11 @@ class MultiScaleFlipAug3D(BaseTransform):
directions for images, options are "horizontal" and "vertical". directions for images, options are "horizontal" and "vertical".
If flip_direction is list, multiple flip augmentations will If flip_direction is list, multiple flip augmentations will
be applied. It has no effect when ``flip == False``. be applied. It has no effect when ``flip == False``.
Defaults to "horizontal". Defaults to 'horizontal'.
pcd_horizontal_flip (bool, optional): Whether apply horizontal pcd_horizontal_flip (bool, optional): Whether to apply horizontal
flip augmentation to point cloud. Defaults to True. flip augmentation to point cloud. Defaults to True.
Note that it works only when 'flip' is turned on. Note that it works only when 'flip' is turned on.
pcd_vertical_flip (bool, optional): Whether apply vertical flip pcd_vertical_flip (bool, optional): Whether to apply vertical flip
augmentation to point cloud. Defaults to True. augmentation to point cloud. Defaults to True.
Note that it works only when 'flip' is turned on. Note that it works only when 'flip' is turned on.
""" """
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import random import random
import warnings import warnings
from typing import Dict, List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import cv2 import cv2
import numpy as np import numpy as np
...@@ -76,7 +76,6 @@ class RandomFlip3D(RandomFlip): ...@@ -76,7 +76,6 @@ class RandomFlip3D(RandomFlip):
otherwise it will be randomly decided by a ratio specified in the init otherwise it will be randomly decided by a ratio specified in the init
method. method.
Required Keys: Required Keys:
- points (np.float32) - points (np.float32)
...@@ -329,7 +328,7 @@ class ObjectSample(BaseTransform): ...@@ -329,7 +328,7 @@ class ObjectSample(BaseTransform):
def __init__(self, def __init__(self,
db_sampler: dict, db_sampler: dict,
sample_2d: bool = False, sample_2d: bool = False,
use_ground_plane: bool = False): use_ground_plane: bool = False) -> None:
self.sampler_cfg = db_sampler self.sampler_cfg = db_sampler
self.sample_2d = sample_2d self.sample_2d = sample_2d
if 'type' not in db_sampler.keys(): if 'type' not in db_sampler.keys():
...@@ -456,10 +455,10 @@ class ObjectNoise(BaseTransform): ...@@ -456,10 +455,10 @@ class ObjectNoise(BaseTransform):
""" """
def __init__(self, def __init__(self,
translation_std: list = [0.25, 0.25, 0.25], translation_std: List[float] = [0.25, 0.25, 0.25],
global_rot_range: list = [0.0, 0.0], global_rot_range: List[float] = [0.0, 0.0],
rot_range: list = [-0.15707963267, 0.15707963267], rot_range: List[float] = [-0.15707963267, 0.15707963267],
num_try: int = 100): num_try: int = 100) -> None:
self.translation_std = translation_std self.translation_std = translation_std
self.global_rot_range = global_rot_range self.global_rot_range = global_rot_range
self.rot_range = rot_range self.rot_range = rot_range
...@@ -522,7 +521,7 @@ class GlobalAlignment(BaseTransform): ...@@ -522,7 +521,7 @@ class GlobalAlignment(BaseTransform):
def __init__(self, rotation_axis: int) -> None: def __init__(self, rotation_axis: int) -> None:
self.rotation_axis = rotation_axis self.rotation_axis = rotation_axis
def _trans_points(self, results: Dict, trans_factor: np.ndarray) -> None: def _trans_points(self, results: dict, trans_factor: np.ndarray) -> None:
"""Private function to translate points. """Private function to translate points.
Args: Args:
...@@ -534,7 +533,7 @@ class GlobalAlignment(BaseTransform): ...@@ -534,7 +533,7 @@ class GlobalAlignment(BaseTransform):
""" """
results['points'].translate(trans_factor) results['points'].translate(trans_factor)
def _rot_points(self, results: Dict, rot_mat: np.ndarray) -> None: def _rot_points(self, results: dict, rot_mat: np.ndarray) -> None:
"""Private function to rotate bounding boxes and points. """Private function to rotate bounding boxes and points.
Args: Args:
...@@ -560,7 +559,7 @@ class GlobalAlignment(BaseTransform): ...@@ -560,7 +559,7 @@ class GlobalAlignment(BaseTransform):
is_valid &= (rot_mat[:, self.rotation_axis] == valid_array).all() is_valid &= (rot_mat[:, self.rotation_axis] == valid_array).all()
assert is_valid, f'invalid rotation matrix {rot_mat}' assert is_valid, f'invalid rotation matrix {rot_mat}'
def transform(self, results: Dict) -> Dict: def transform(self, results: dict) -> dict:
"""Call function to shuffle points. """Call function to shuffle points.
Args: Args:
...@@ -586,6 +585,7 @@ class GlobalAlignment(BaseTransform): ...@@ -586,6 +585,7 @@ class GlobalAlignment(BaseTransform):
return results return results
def __repr__(self): def __repr__(self):
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__ repr_str = self.__class__.__name__
repr_str += f'(rotation_axis={self.rotation_axis})' repr_str += f'(rotation_axis={self.rotation_axis})'
return repr_str return repr_str
...@@ -804,6 +804,7 @@ class PointShuffle(BaseTransform): ...@@ -804,6 +804,7 @@ class PointShuffle(BaseTransform):
return input_dict return input_dict
def __repr__(self): def __repr__(self):
"""str: Return a string that describes the module."""
return self.__class__.__name__ return self.__class__.__name__
...@@ -823,7 +824,7 @@ class ObjectRangeFilter(BaseTransform): ...@@ -823,7 +824,7 @@ class ObjectRangeFilter(BaseTransform):
point_cloud_range (list[float]): Point cloud range. point_cloud_range (list[float]): Point cloud range.
""" """
def __init__(self, point_cloud_range: list): def __init__(self, point_cloud_range: List[float]):
self.pcd_range = np.array(point_cloud_range, dtype=np.float32) self.pcd_range = np.array(point_cloud_range, dtype=np.float32)
def transform(self, input_dict: dict) -> dict: def transform(self, input_dict: dict) -> dict:
...@@ -885,7 +886,7 @@ class PointsRangeFilter(BaseTransform): ...@@ -885,7 +886,7 @@ class PointsRangeFilter(BaseTransform):
point_cloud_range (list[float]): Point cloud range. point_cloud_range (list[float]): Point cloud range.
""" """
def __init__(self, point_cloud_range: list): def __init__(self, point_cloud_range: List[float]) -> None:
self.pcd_range = np.array(point_cloud_range, dtype=np.float32) self.pcd_range = np.array(point_cloud_range, dtype=np.float32)
def transform(self, input_dict: dict) -> dict: def transform(self, input_dict: dict) -> dict:
...@@ -938,7 +939,7 @@ class ObjectNameFilter(BaseTransform): ...@@ -938,7 +939,7 @@ class ObjectNameFilter(BaseTransform):
classes (list[str]): List of class names to be kept for training. classes (list[str]): List of class names to be kept for training.
""" """
def __init__(self, classes: list): def __init__(self, classes: List[str]) -> None:
self.classes = classes self.classes = classes
self.labels = list(range(len(self.classes))) self.labels = list(range(len(self.classes)))
...@@ -996,34 +997,38 @@ class PointSample(BaseTransform): ...@@ -996,34 +997,38 @@ class PointSample(BaseTransform):
def __init__(self, def __init__(self,
num_points: int, num_points: int,
sample_range: float = None, sample_range: Optional[float] = None,
replace: bool = False): replace: bool = False) -> None:
self.num_points = num_points self.num_points = num_points
self.sample_range = sample_range self.sample_range = sample_range
self.replace = replace self.replace = replace
def _points_random_sampling(self, def _points_random_sampling(
points, self,
num_samples, points: BasePoints,
sample_range=None, num_samples: int,
replace=False, sample_range: Optional[float] = None,
return_choices=False): replace: bool = False,
return_choices: bool = False
) -> Union[Tuple[BasePoints, np.ndarray], BasePoints]:
"""Points random sampling. """Points random sampling.
Sample points to a certain number. Sample points to a certain number.
Args: Args:
points (np.ndarray | :obj:`BasePoints`): 3D Points. points (:obj:`BasePoints`): 3D Points.
num_samples (int): Number of samples to be sampled. num_samples (int): Number of samples to be sampled.
sample_range (float, optional): Indicating the range where the sample_range (float, optional): Indicating the range where the
points will be sampled. Defaults to None. points will be sampled. Defaults to None.
replace (bool, optional): Sampling with or without replacement. replace (bool, optional): Sampling with or without replacement.
Defaults to None. Defaults to False.
return_choices (bool, optional): Whether return choice. return_choices (bool, optional): Whether return choice.
Defaults to False. Defaults to False.
Returns: Returns:
tuple[np.ndarray] | np.ndarray: tuple[:obj:`BasePoints`, np.ndarray] | :obj:`BasePoints`:
- points (np.ndarray | :obj:`BasePoints`): 3D Points.
- points (:obj:`BasePoints`): 3D Points.
- choices (np.ndarray, optional): The generated random samples. - choices (np.ndarray, optional): The generated random samples.
""" """
if not replace: if not replace:
...@@ -1031,7 +1036,7 @@ class PointSample(BaseTransform): ...@@ -1031,7 +1036,7 @@ class PointSample(BaseTransform):
point_range = range(len(points)) point_range = range(len(points))
if sample_range is not None and not replace: if sample_range is not None and not replace:
# Only sampling the near points when len(points) >= num_samples # Only sampling the near points when len(points) >= num_samples
dist = np.linalg.norm(points.tensor, axis=1) dist = np.linalg.norm(points.coord.numpy(), axis=1)
far_inds = np.where(dist >= sample_range)[0] far_inds = np.where(dist >= sample_range)[0]
near_inds = np.where(dist < sample_range)[0] near_inds = np.where(dist < sample_range)[0]
# in case there are too many far points # in case there are too many far points
...@@ -1055,6 +1060,7 @@ class PointSample(BaseTransform): ...@@ -1055,6 +1060,7 @@ class PointSample(BaseTransform):
Args: Args:
input_dict (dict): Result dict from loading pipeline. input_dict (dict): Result dict from loading pipeline.
Returns: Returns:
dict: Results after sampling, 'points', 'pts_instance_mask' dict: Results after sampling, 'points', 'pts_instance_mask'
and 'pts_semantic_mask' keys are updated in the result dict. and 'pts_semantic_mask' keys are updated in the result dict.
...@@ -1214,8 +1220,11 @@ class IndoorPatchPointSample(BaseTransform): ...@@ -1214,8 +1220,11 @@ class IndoorPatchPointSample(BaseTransform):
return points return points
def _patch_points_sampling(self, points: BasePoints, def _patch_points_sampling(
sem_mask: np.ndarray) -> BasePoints: self,
points: BasePoints,
sem_mask: np.ndarray
) -> Tuple[BasePoints, np.ndarray]:
"""Patch points sampling. """Patch points sampling.
First sample a valid patch. First sample a valid patch.
...@@ -1226,7 +1235,7 @@ class IndoorPatchPointSample(BaseTransform): ...@@ -1226,7 +1235,7 @@ class IndoorPatchPointSample(BaseTransform):
sem_mask (np.ndarray): semantic segmentation mask for input points. sem_mask (np.ndarray): semantic segmentation mask for input points.
Returns: Returns:
tuple[:obj:`BasePoints`, np.ndarray] | :obj:`BasePoints`: tuple[:obj:`BasePoints`, np.ndarray]:
- points (:obj:`BasePoints`): 3D Points. - points (:obj:`BasePoints`): 3D Points.
- choices (np.ndarray): The generated random samples. - choices (np.ndarray): The generated random samples.
...@@ -1433,7 +1442,7 @@ class BackgroundPointsFilter(BaseTransform): ...@@ -1433,7 +1442,7 @@ class BackgroundPointsFilter(BaseTransform):
@TRANSFORMS.register_module() @TRANSFORMS.register_module()
class VoxelBasedPointSampler(object): class VoxelBasedPointSampler(BaseTransform):
"""Voxel based point sampler. """Voxel based point sampler.
Apply voxel sampling to multiple sweep points. Apply voxel sampling to multiple sweep points.
...@@ -1445,7 +1454,10 @@ class VoxelBasedPointSampler(object): ...@@ -1445,7 +1454,10 @@ class VoxelBasedPointSampler(object):
for input points. for input points.
""" """
def __init__(self, cur_sweep_cfg, prev_sweep_cfg=None, time_dim=3): def __init__(self,
cur_sweep_cfg: dict,
prev_sweep_cfg: Optional[dict] = None,
time_dim: int = 3) -> None:
self.cur_voxel_generator = VoxelGenerator(**cur_sweep_cfg) self.cur_voxel_generator = VoxelGenerator(**cur_sweep_cfg)
self.cur_voxel_num = self.cur_voxel_generator._max_voxels self.cur_voxel_num = self.cur_voxel_generator._max_voxels
self.time_dim = time_dim self.time_dim = time_dim
...@@ -1458,7 +1470,10 @@ class VoxelBasedPointSampler(object): ...@@ -1458,7 +1470,10 @@ class VoxelBasedPointSampler(object):
self.prev_voxel_generator = None self.prev_voxel_generator = None
self.prev_voxel_num = 0 self.prev_voxel_num = 0
def _sample_points(self, points, sampler, point_dim): def _sample_points(self,
points: np.ndarray,
sampler: VoxelGenerator,
point_dim: int) -> np.ndarray:
"""Sample points for each points subset. """Sample points for each points subset.
Args: Args:
...@@ -1484,7 +1499,7 @@ class VoxelBasedPointSampler(object): ...@@ -1484,7 +1499,7 @@ class VoxelBasedPointSampler(object):
return sample_points return sample_points
def __call__(self, results): def transform(self, results: dict) -> dict:
"""Call function to sample points from multiple sweeps. """Call function to sample points from multiple sweeps.
Args: Args:
...@@ -1766,6 +1781,7 @@ class AffineResize(BaseTransform): ...@@ -1766,6 +1781,7 @@ class AffineResize(BaseTransform):
return ref_point3 return ref_point3
def __repr__(self): def __repr__(self):
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__ repr_str = self.__class__.__name__
repr_str += f'(img_scale={self.img_scale}, ' repr_str += f'(img_scale={self.img_scale}, '
repr_str += f'down_ratio={self.down_ratio}) ' repr_str += f'down_ratio={self.down_ratio}) '
...@@ -1786,7 +1802,7 @@ class RandomShiftScale(BaseTransform): ...@@ -1786,7 +1802,7 @@ class RandomShiftScale(BaseTransform):
aug_prob (float): The shifting and scaling probability. aug_prob (float): The shifting and scaling probability.
""" """
def __init__(self, shift_scale: Tuple[float], aug_prob: float): def __init__(self, shift_scale: Tuple[float], aug_prob: float) -> None:
self.shift_scale = shift_scale self.shift_scale = shift_scale
self.aug_prob = aug_prob self.aug_prob = aug_prob
...@@ -1825,6 +1841,7 @@ class RandomShiftScale(BaseTransform): ...@@ -1825,6 +1841,7 @@ class RandomShiftScale(BaseTransform):
return results return results
def __repr__(self): def __repr__(self):
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__ repr_str = self.__class__.__name__
repr_str += f'(shift_scale={self.shift_scale}, ' repr_str += f'(shift_scale={self.shift_scale}, '
repr_str += f'aug_prob={self.aug_prob}) ' repr_str += f'aug_prob={self.aug_prob}) '
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment