Unverified Commit 5412046b authored by Xiangxu-0103, committed by GitHub

[Enhance]: Add typehints for dataset transforms and fix potential bug for `PointSample` (#1875)

* update dataset transforms

* update dbsampler docstring and add typehints

* add type hints and fix potential point sample bug

* fix lint

* fix

* fix
parent d8c9bc66
......@@ -32,7 +32,7 @@ class Compose:
data (dict): A result dict contains the data to transform.
Returns:
dict: Transformed data.
dict: Transformed data.
"""
for t in self.transforms:
......
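For context, `Compose` simply applies each transform in order and aborts if any transform returns None; a minimal standalone sketch of that loop (illustrative only, not the mmengine implementation):

# Hedged sketch of the Compose loop described above.
from typing import Callable, List, Optional

def compose_apply(transforms: List[Callable[[dict], Optional[dict]]],
                  data: dict) -> Optional[dict]:
    for t in transforms:
        data = t(data)
        if data is None:
            return None
    return data

print(compose_apply([lambda d: {**d, 'a': 1}, lambda d: {**d, 'b': 2}], {}))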
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os
import warnings
from typing import List, Optional
import mmengine
import numpy as np
......@@ -16,18 +16,19 @@ class BatchSampler:
Args:
sample_list (list[dict]): List of samples.
name (str, optional): The category of samples. Default: None.
epoch (int, optional): Sampling epoch. Default: None.
shuffle (bool, optional): Whether to shuffle indices. Default: False.
drop_reminder (bool, optional): Drop reminder. Default: False.
name (str, optional): The category of samples. Defaults to None.
epoch (int, optional): Sampling epoch. Defaults to None.
shuffle (bool, optional): Whether to shuffle indices.
Defaults to False.
drop_reminder (bool, optional): Drop reminder. Defaults to False.
"""
def __init__(self,
sampled_list,
name=None,
epoch=None,
shuffle=True,
drop_reminder=False):
sampled_list: List[dict],
name: Optional[str] = None,
epoch: Optional[int] = None,
shuffle: bool = True,
drop_reminder: bool = False) -> None:
self._sampled_list = sampled_list
self._indices = np.arange(len(sampled_list))
if shuffle:
......@@ -40,7 +41,7 @@ class BatchSampler:
self._epoch_counter = 0
self._drop_reminder = drop_reminder
def _sample(self, num):
def _sample(self, num: int) -> List[int]:
"""Sample specific number of ground truths and return indices.
Args:
......@@ -57,7 +58,7 @@ class BatchSampler:
self._idx += num
return ret
def _reset(self):
def _reset(self) -> None:
"""Reset the index of batchsampler to zero."""
assert self._name is not None
# print("reset", self._name)
......@@ -65,7 +66,7 @@ class BatchSampler:
np.random.shuffle(self._indices)
self._idx = 0
def sample(self, num):
def sample(self, num: int) -> List[dict]:
"""Sample specific number of ground truths.
Args:
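To make the shuffle/reset behaviour documented above concrete, here is a small self-contained sketch of the same cycling idea (it mirrors the documented intent but is not the mmdet3d BatchSampler implementation):

# Hedged sketch: shuffle once, hand out `num` items per call, reshuffle and
# restart once the index pool is exhausted.
import numpy as np

class CyclicSampler:
    def __init__(self, sampled_list, shuffle=True):
        self._sampled_list = sampled_list
        self._indices = np.arange(len(sampled_list))
        self._shuffle = shuffle
        if shuffle:
            np.random.shuffle(self._indices)
        self._idx = 0

    def sample(self, num):
        if self._idx + num >= len(self._sampled_list):
            if self._shuffle:
                np.random.shuffle(self._indices)
            self._idx = 0
        ret = self._indices[self._idx:self._idx + num]
        self._idx += num
        return [self._sampled_list[i] for i in ret]

sampler = CyclicSampler([{'name': f'gt_{i}'} for i in range(5)])
print(sampler.sample(3))
print(sampler.sample(3))  # triggers a reshuffle before sampling again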
......@@ -88,24 +89,28 @@ class DataBaseSampler(object):
rate (float): Rate of actual sampled over maximum sampled number.
prepare (dict): Name of preparation functions and the input value.
sample_groups (dict): Sampled classes and numbers.
classes (list[str], optional): List of classes. Default: None.
points_loader(dict, optional): Config of points loader. Default:
dict(type='LoadPointsFromFile', load_dim=4, use_dim=[0,1,2,3])
classes (list[str], optional): List of classes. Defaults to None.
points_loader(dict, optional): Config of points loader. Defaults to
dict(type='LoadPointsFromFile', load_dim=4, use_dim=[0, 1, 2, 3]).
file_client_args (dict, optional): Config dict of file clients,
refer to
https://github.com/open-mmlab/mmengine/blob/main/mmengine/fileio/file_client.py
for more details. Defaults to dict(backend='disk').
"""
def __init__(self,
info_path,
data_root,
rate,
prepare,
sample_groups,
classes=None,
points_loader=dict(
info_path: str,
data_root: str,
rate: float,
prepare: dict,
sample_groups: dict,
classes: Optional[List[str]] = None,
points_loader: dict = dict(
type='LoadPointsFromFile',
coord_type='LIDAR',
load_dim=4,
use_dim=[0, 1, 2, 3]),
file_client_args=dict(backend='disk')):
file_client_args: dict = dict(backend='disk')) -> None:
super().__init__()
self.data_root = data_root
self.info_path = info_path
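For orientation, a db_sampler config matching this signature typically looks like the sketch below; the paths, classes and sample counts are placeholders for illustration, not values taken from this commit:

# Hedged example config; the keys mirror the __init__ signature above,
# the concrete values are placeholders.
db_sampler = dict(
    type='DataBaseSampler',
    data_root='data/kitti/',                          # placeholder path
    info_path='data/kitti/kitti_dbinfos_train.pkl',   # placeholder path
    rate=1.0,
    prepare=dict(
        filter_by_difficulty=[-1],
        filter_by_min_points=dict(Car=5)),
    sample_groups=dict(Car=15),
    classes=['Car'],
    points_loader=dict(
        type='LoadPointsFromFile',
        coord_type='LIDAR',
        load_dim=4,
        use_dim=[0, 1, 2, 3]),
    file_client_args=dict(backend='disk'))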
......@@ -118,18 +123,9 @@ class DataBaseSampler(object):
self.file_client = mmengine.FileClient(**file_client_args)
# load data base infos
if hasattr(self.file_client, 'get_local_path'):
with self.file_client.get_local_path(info_path) as local_path:
# loading data from a file-like object needs file format
db_infos = mmengine.load(
open(local_path, 'rb'), file_format='pkl')
else:
warnings.warn(
'The used MMCV version does not have get_local_path. '
f'We treat the {info_path} as local paths and it '
'might cause errors if the path is not a local path. '
'Please use MMCV>= 1.3.16 if you meet errors.')
db_infos = mmengine.load(info_path)
with self.file_client.get_local_path(info_path) as local_path:
# loading data from a file-like object needs file format
db_infos = mmengine.load(open(local_path, 'rb'), file_format='pkl')
# filter database infos
from mmengine.logging import MMLogger
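The simplified branch above always goes through `get_local_path`; as a standalone illustration of that pattern (with a throwaway placeholder file so the snippet runs end-to-end):

# Hedged sketch of the get_local_path + mmengine.load pattern used above;
# 'demo_dbinfos.pkl' is a throwaway placeholder, not a file from this repo.
import mmengine

mmengine.dump({'Car': []}, 'demo_dbinfos.pkl')
file_client = mmengine.FileClient(backend='disk')
with file_client.get_local_path('demo_dbinfos.pkl') as local_path:
    # Loading from a file-like object needs an explicit file format.
    db_infos = mmengine.load(open(local_path, 'rb'), file_format='pkl')
print(db_infos)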
......@@ -163,7 +159,7 @@ class DataBaseSampler(object):
# TODO: No group_sampling currently
@staticmethod
def filter_by_difficulty(db_infos, removed_difficulty):
def filter_by_difficulty(db_infos: dict, removed_difficulty: list) -> dict:
"""Filter ground truths by difficulties.
Args:
......@@ -182,7 +178,7 @@ class DataBaseSampler(object):
return new_db_infos
@staticmethod
def filter_by_min_points(db_infos, min_gt_points_dict):
def filter_by_min_points(db_infos: dict, min_gt_points_dict: dict) -> dict:
"""Filter ground truths by number of points in the bbox.
Args:
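Both filters prune db_infos per category; a minimal dict-only sketch of the min-points variant (the field names follow the usual dbinfos layout and are an assumption here, not part of this diff):

# Hedged sketch of the filtering idea only; real dbinfos entries carry more
# fields than shown.
def filter_by_min_points_demo(db_infos: dict, min_gt_points_dict: dict) -> dict:
    for name, min_num in min_gt_points_dict.items():
        min_num = int(min_num)
        if min_num > 0:
            db_infos[name] = [
                info for info in db_infos[name]
                if info['num_points_in_gt'] >= min_num  # assumed field name
            ]
    return db_infos

demo = {'Car': [{'num_points_in_gt': 3}, {'num_points_in_gt': 12}]}
print(filter_by_min_points_demo(demo, {'Car': 5}))  # keeps only the 12-point box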
......@@ -203,12 +199,19 @@ class DataBaseSampler(object):
db_infos[name] = filtered_infos
return db_infos
def sample_all(self, gt_bboxes, gt_labels, img=None, ground_plane=None):
def sample_all(self,
gt_bboxes: np.ndarray,
gt_labels: np.ndarray,
img: Optional[np.ndarray] = None,
ground_plane: Optional[np.ndarray] = None) -> dict:
"""Sampling all categories of bboxes.
Args:
gt_bboxes (np.ndarray): Ground truth bounding boxes.
gt_labels (np.ndarray): Ground truth labels of boxes.
img (np.ndarray, optional): Image array. Defaults to None.
ground_plane (np.ndarray, optional): Ground plane information.
Defaults to None.
Returns:
dict: Dict of sampled 'pseudo ground truths'.
......@@ -301,7 +304,10 @@ class DataBaseSampler(object):
return ret
def sample_class_v2(self, name, num, gt_bboxes):
def sample_class_v2(self,
name: str,
num: int,
gt_bboxes: np.ndarray) -> List[dict]:
"""Sampling specific categories of bounding boxes.
Args:
......
......@@ -63,16 +63,16 @@ class Pack3DDetInputs(BaseTransform):
def __init__(
self,
keys: dict,
meta_keys: dict = ('img_path', 'ori_shape', 'img_shape', 'lidar2img',
'depth2img', 'cam2img', 'pad_shape', 'scale_factor',
'flip', 'pcd_horizontal_flip', 'pcd_vertical_flip',
'box_mode_3d', 'box_type_3d', 'img_norm_cfg',
'num_pts_feats', 'pcd_trans', 'sample_idx',
'pcd_scale_factor', 'pcd_rotation',
'pcd_rotation_angle', 'lidar_path',
'transformation_3d_flow', 'trans_mat',
'affine_aug')):
keys: tuple,
meta_keys: tuple = ('img_path', 'ori_shape', 'img_shape', 'lidar2img',
'depth2img', 'cam2img', 'pad_shape',
'scale_factor', 'flip', 'pcd_horizontal_flip',
'pcd_vertical_flip', 'box_mode_3d', 'box_type_3d',
'img_norm_cfg', 'num_pts_feats', 'pcd_trans',
'sample_idx', 'pcd_scale_factor', 'pcd_rotation',
'pcd_rotation_angle', 'lidar_path',
'transformation_3d_flow', 'trans_mat',
'affine_aug')) -> None:
self.keys = keys
self.meta_keys = meta_keys
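With `keys` and `meta_keys` now annotated as tuples, a typical pipeline entry is a plain config dict such as the one below; the exact key set depends on the task and is shown as an illustrative assumption:

# Hedged example of a Pack3DDetInputs pipeline entry; the key names are the
# common LiDAR-detection ones and are an assumption, not part of this commit.
pack_inputs = dict(
    type='Pack3DDetInputs',
    keys=('points', 'gt_bboxes_3d', 'gt_labels_3d'))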
......@@ -99,7 +99,7 @@ class Pack3DDetInputs(BaseTransform):
- img
- 'data_samples' (obj:`Det3DDataSample`): The annotation info of
the sample.
the sample.
"""
# augtest
if isinstance(results, list):
......@@ -116,7 +116,7 @@ class Pack3DDetInputs(BaseTransform):
else:
raise NotImplementedError
def pack_single_results(self, results):
def pack_single_results(self, results: dict) -> dict:
"""Method to pack the single input data. when the value in this dict is
a list, it usually is in Augmentations Testing.
......@@ -132,7 +132,7 @@ class Pack3DDetInputs(BaseTransform):
- points
- img
- 'data_samples' (obj:`Det3DDataSample`): The annotation info
- 'data_samples' (:obj:`Det3DDataSample`): The annotation info
of the sample.
"""
# Format 3D data
......@@ -220,6 +220,7 @@ class Pack3DDetInputs(BaseTransform):
return packed_results
def __repr__(self) -> str:
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f'(keys={self.keys})'
repr_str += f'(meta_keys={self.meta_keys})'
......
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
from typing import List, Union
import mmcv
import mmengine
......@@ -13,7 +13,7 @@ from mmdet.datasets.transforms import LoadAnnotations
@TRANSFORMS.register_module()
class LoadMultiViewImageFromFiles(object):
class LoadMultiViewImageFromFiles(BaseTransform):
"""Load multi channel images from a list of separate channel files.
Expects results['img_filename'] to be a list of filenames.
......@@ -25,11 +25,15 @@ class LoadMultiViewImageFromFiles(object):
Defaults to 'unchanged'.
"""
def __init__(self, to_float32=False, color_type='unchanged'):
def __init__(
self,
to_float32: bool = False,
color_type: str = 'unchanged'
) -> None:
self.to_float32 = to_float32
self.color_type = color_type
def __call__(self, results):
def transform(self, results: dict) -> dict:
"""Call function to load multi-view image from files.
Args:
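The rename from `__call__` to `transform` follows the mmcv `BaseTransform` interface, whose `__call__` dispatches to `transform`; a toy subclass showing the required shape (illustrative only, not a transform from this repo):

# Hedged sketch of the BaseTransform interface the class now inherits.
from mmcv.transforms import BaseTransform

class AddConstant(BaseTransform):
    """Toy transform used only to illustrate the interface."""

    def __init__(self, value: int = 1) -> None:
        self.value = value

    def transform(self, results: dict) -> dict:
        results['x'] = results.get('x', 0) + self.value
        return results

print(AddConstant(2)({'x': 1}))  # __call__ delegates to transform()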
......@@ -139,7 +143,7 @@ class LoadPointsFromMultiSweeps(BaseTransform):
Defaults to [0, 1, 2, 4].
file_client_args (dict, optional): Config dict of file clients,
refer to
https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py
https://github.com/open-mmlab/mmengine/blob/main/mmengine/fileio/file_client.py
for more details. Defaults to dict(backend='disk').
pad_empty_sweeps (bool, optional): Whether to repeat keyframe when
sweeps is empty. Defaults to False.
......@@ -150,14 +154,16 @@ class LoadPointsFromMultiSweeps(BaseTransform):
Defaults to False.
"""
def __init__(self,
sweeps_num=10,
load_dim=5,
use_dim=[0, 1, 2, 4],
file_client_args=dict(backend='disk'),
pad_empty_sweeps=False,
remove_close=False,
test_mode=False):
def __init__(
self,
sweeps_num: int = 10,
load_dim: int = 5,
use_dim: List[int] = [0, 1, 2, 4],
file_client_args: dict = dict(backend='disk'),
pad_empty_sweeps: bool = False,
remove_close: bool = False,
test_mode: bool = False
) -> None:
self.load_dim = load_dim
self.sweeps_num = sweeps_num
self.use_dim = use_dim
......@@ -167,7 +173,7 @@ class LoadPointsFromMultiSweeps(BaseTransform):
self.remove_close = remove_close
self.test_mode = test_mode
def _load_points(self, pts_filename):
def _load_points(self, pts_filename: str) -> np.ndarray:
"""Private function to load point clouds data.
Args:
......@@ -189,7 +195,11 @@ class LoadPointsFromMultiSweeps(BaseTransform):
points = np.fromfile(pts_filename, dtype=np.float32)
return points
def _remove_close(self, points, radius=1.0):
def _remove_close(
self,
points: Union[np.ndarray, BasePoints],
radius: float = 1.0
) -> Union[np.ndarray, BasePoints]:
"""Removes point too close within a certain radius from origin.
Args:
......@@ -198,7 +208,7 @@ class LoadPointsFromMultiSweeps(BaseTransform):
Defaults to 1.0.
Returns:
np.ndarray: Points after removing.
np.ndarray | :obj:`BasePoints`: Points after removing.
"""
if isinstance(points, np.ndarray):
points_numpy = points
......@@ -211,7 +221,7 @@ class LoadPointsFromMultiSweeps(BaseTransform):
not_close = np.logical_not(np.logical_and(x_filt, y_filt))
return points[not_close]
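The masking above keeps every point whose x or y magnitude is at least `radius` away from the origin; the same logic in plain numpy:

# Hedged numpy-only sketch of the _remove_close masking shown above.
import numpy as np

def remove_close(points_numpy: np.ndarray, radius: float = 1.0) -> np.ndarray:
    x_filt = np.abs(points_numpy[:, 0]) < radius
    y_filt = np.abs(points_numpy[:, 1]) < radius
    not_close = np.logical_not(np.logical_and(x_filt, y_filt))
    return points_numpy[not_close]

pts = np.array([[0.2, 0.3, 0.0], [5.0, 0.1, 0.0], [0.4, 7.0, -1.0]])
print(remove_close(pts))  # drops only the first point (close in both x and y)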
def transform(self, results):
def transform(self, results: dict) -> dict:
"""Call function to load multi-sweep point clouds from files.
Args:
......@@ -220,7 +230,7 @@ class LoadPointsFromMultiSweeps(BaseTransform):
Returns:
dict: The result dict containing the multi-sweep points data.
Added key and value are described below.
Updated key and value are described below.
- points (np.ndarray | :obj:`BasePoints`): Multi-sweep point
cloud arrays.
......@@ -290,7 +300,7 @@ class PointSegClassMapping(BaseTransform):
others as len(valid_cat_ids).
"""
def transform(self, results: dict) -> None:
def transform(self, results: dict) -> dict:
"""Call function to map original semantic class to valid category ids.
Args:
......@@ -322,8 +332,6 @@ class PointSegClassMapping(BaseTransform):
def __repr__(self):
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f'(valid_cat_ids={self.valid_cat_ids}, '
repr_str += f'max_cat_id={self.max_cat_id})'
return repr_str
......@@ -385,13 +393,14 @@ class LoadPointsFromFile(BaseTransform):
Args:
coord_type (str): The type of coordinates of points cloud.
Available options includes:
- 'LIDAR': Points in LiDAR coordinates.
- 'DEPTH': Points in depth coordinates, usually for indoor dataset.
- 'CAMERA': Points in camera coordinates.
load_dim (int, optional): The dimension of the loaded points.
Defaults to 6.
use_dim (list[int], optional): Which dimensions of the points to use.
Defaults to [0, 1, 2]. For KITTI dataset, set use_dim=4
use_dim (list[int] | int, optional): Which dimensions of the points
to use. Defaults to [0, 1, 2]. For KITTI dataset, set use_dim=4
or use_dim=[0, 1, 2, 3] to use the intensity dimension.
shift_height (bool, optional): Whether to use shifted height.
Defaults to False.
......@@ -399,7 +408,7 @@ class LoadPointsFromFile(BaseTransform):
Defaults to False.
file_client_args (dict, optional): Config dict of file clients,
refer to
https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py
https://github.com/open-mmlab/mmengine/blob/main/mmengine/fileio/file_client.py
for more details. Defaults to dict(backend='disk').
"""
......@@ -407,7 +416,7 @@ class LoadPointsFromFile(BaseTransform):
self,
coord_type: str,
load_dim: int = 6,
use_dim: list = [0, 1, 2],
use_dim: Union[int, List[int]] = [0, 1, 2],
shift_height: bool = False,
use_color: bool = False,
file_client_args: dict = dict(backend='disk')
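The widened `use_dim` annotation reflects that both an int and a list are accepted; for 4-dim KITTI-style points the two spellings below are equivalent (illustrative config, not from this commit):

# Hedged illustration of the two equivalent use_dim spellings.
load_points = dict(
    type='LoadPointsFromFile',
    coord_type='LIDAR',
    load_dim=4,
    use_dim=4)  # same as use_dim=[0, 1, 2, 3]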
......@@ -523,6 +532,7 @@ class LoadAnnotations3D(LoadAnnotations):
Required Keys:
- ann_info (dict)
- gt_bboxes_3d (:obj:`LiDARInstance3DBoxes` |
:obj:`DepthInstance3DBoxes` | :obj:`CameraInstance3DBoxes`):
3D ground truth bboxes. Only when `with_bbox_3d` is True
......@@ -592,7 +602,7 @@ class LoadAnnotations3D(LoadAnnotations):
seg_3d_dtype (dtype, optional): Dtype of 3D semantic masks.
Defaults to int64.
file_client_args (dict): Config dict of file clients, refer to
https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py
https://github.com/open-mmlab/mmengine/blob/main/mmengine/fileio/file_client.py
for more details.
"""
......
......@@ -16,7 +16,7 @@ class MultiScaleFlipAug3D(BaseTransform):
Args:
transforms (list[dict]): Transforms to apply in each augmentation.
img_scale (tuple | list[tuple]: Images scales for resizing.
img_scale (tuple | list[tuple]): Images scales for resizing.
pts_scale_ratio (float | list[float]): Points scale ratios for
resizing.
flip (bool, optional): Whether apply flip augmentation.
......@@ -25,11 +25,11 @@ class MultiScaleFlipAug3D(BaseTransform):
directions for images, options are "horizontal" and "vertical".
If flip_direction is list, multiple flip augmentations will
be applied. It has no effect when ``flip == False``.
Defaults to "horizontal".
pcd_horizontal_flip (bool, optional): Whether apply horizontal
Defaults to 'horizontal'.
pcd_horizontal_flip (bool, optional): Whether to apply horizontal
flip augmentation to point cloud. Defaults to True.
Note that it works only when 'flip' is turned on.
pcd_vertical_flip (bool, optional): Whether apply vertical flip
pcd_vertical_flip (bool, optional): Whether to apply vertical flip
augmentation to point cloud. Defaults to True.
Note that it works only when 'flip' is turned on.
"""
......
# Copyright (c) OpenMMLab. All rights reserved.
import random
import warnings
from typing import Dict, List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union
import cv2
import numpy as np
......@@ -76,7 +76,6 @@ class RandomFlip3D(RandomFlip):
otherwise it will be randomly decided by a ratio specified in the init
method.
Required Keys:
- points (np.float32)
......@@ -329,7 +328,7 @@ class ObjectSample(BaseTransform):
def __init__(self,
db_sampler: dict,
sample_2d: bool = False,
use_ground_plane: bool = False):
use_ground_plane: bool = False) -> None:
self.sampler_cfg = db_sampler
self.sample_2d = sample_2d
if 'type' not in db_sampler.keys():
......@@ -456,10 +455,10 @@ class ObjectNoise(BaseTransform):
"""
def __init__(self,
translation_std: list = [0.25, 0.25, 0.25],
global_rot_range: list = [0.0, 0.0],
rot_range: list = [-0.15707963267, 0.15707963267],
num_try: int = 100):
translation_std: List[float] = [0.25, 0.25, 0.25],
global_rot_range: List[float] = [0.0, 0.0],
rot_range: List[float] = [-0.15707963267, 0.15707963267],
num_try: int = 100) -> None:
self.translation_std = translation_std
self.global_rot_range = global_rot_range
self.rot_range = rot_range
......@@ -522,7 +521,7 @@ class GlobalAlignment(BaseTransform):
def __init__(self, rotation_axis: int) -> None:
self.rotation_axis = rotation_axis
def _trans_points(self, results: Dict, trans_factor: np.ndarray) -> None:
def _trans_points(self, results: dict, trans_factor: np.ndarray) -> None:
"""Private function to translate points.
Args:
......@@ -534,7 +533,7 @@ class GlobalAlignment(BaseTransform):
"""
results['points'].translate(trans_factor)
def _rot_points(self, results: Dict, rot_mat: np.ndarray) -> None:
def _rot_points(self, results: dict, rot_mat: np.ndarray) -> None:
"""Private function to rotate bounding boxes and points.
Args:
......@@ -560,7 +559,7 @@ class GlobalAlignment(BaseTransform):
is_valid &= (rot_mat[:, self.rotation_axis] == valid_array).all()
assert is_valid, f'invalid rotation matrix {rot_mat}'
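The assertion above demands that the chosen rotation axis is left untouched by the rotation matrix (its row and column stay one-hot); a small numpy check of that property for a z-axis rotation:

# Hedged numpy sketch of the axis-validity check used above.
import numpy as np

axis = 2  # rotate about z, so row/column 2 must stay [0, 0, 1]
theta = np.pi / 6
rot_mat = np.array([
    [np.cos(theta), -np.sin(theta), 0.0],
    [np.sin(theta),  np.cos(theta), 0.0],
    [0.0,            0.0,           1.0]])

valid = np.zeros(3)
valid[axis] = 1.0
is_valid = (rot_mat[axis, :] == valid).all() and (rot_mat[:, axis] == valid).all()
print(is_valid)  # True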
def transform(self, results: Dict) -> Dict:
def transform(self, results: dict) -> dict:
"""Call function to shuffle points.
Args:
......@@ -586,6 +585,7 @@ class GlobalAlignment(BaseTransform):
return results
def __repr__(self):
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f'(rotation_axis={self.rotation_axis})'
return repr_str
......@@ -804,6 +804,7 @@ class PointShuffle(BaseTransform):
return input_dict
def __repr__(self):
"""str: Return a string that describes the module."""
return self.__class__.__name__
......@@ -823,7 +824,7 @@ class ObjectRangeFilter(BaseTransform):
point_cloud_range (list[float]): Point cloud range.
"""
def __init__(self, point_cloud_range: list):
def __init__(self, point_cloud_range: List[float]):
self.pcd_range = np.array(point_cloud_range, dtype=np.float32)
def transform(self, input_dict: dict) -> dict:
......@@ -885,7 +886,7 @@ class PointsRangeFilter(BaseTransform):
point_cloud_range (list[float]): Point cloud range.
"""
def __init__(self, point_cloud_range: list):
def __init__(self, point_cloud_range: List[float]) -> None:
self.pcd_range = np.array(point_cloud_range, dtype=np.float32)
def transform(self, input_dict: dict) -> dict:
......@@ -938,7 +939,7 @@ class ObjectNameFilter(BaseTransform):
classes (list[str]): List of class names to be kept for training.
"""
def __init__(self, classes: list):
def __init__(self, classes: List[str]) -> None:
self.classes = classes
self.labels = list(range(len(self.classes)))
......@@ -996,34 +997,38 @@ class PointSample(BaseTransform):
def __init__(self,
num_points: int,
sample_range: float = None,
replace: bool = False):
sample_range: Optional[float] = None,
replace: bool = False) -> None:
self.num_points = num_points
self.sample_range = sample_range
self.replace = replace
def _points_random_sampling(self,
points,
num_samples,
sample_range=None,
replace=False,
return_choices=False):
def _points_random_sampling(
self,
points: BasePoints,
num_samples: int,
sample_range: Optional[float] = None,
replace: bool = False,
return_choices: bool = False
) -> Union[Tuple[BasePoints, np.ndarray], BasePoints]:
"""Points random sampling.
Sample points to a certain number.
Args:
points (np.ndarray | :obj:`BasePoints`): 3D Points.
points (:obj:`BasePoints`): 3D Points.
num_samples (int): Number of samples to be sampled.
sample_range (float, optional): Indicating the range where the
points will be sampled. Defaults to None.
replace (bool, optional): Sampling with or without replacement.
Defaults to None.
Defaults to False.
return_choices (bool, optional): Whether return choice.
Defaults to False.
Returns:
tuple[np.ndarray] | np.ndarray:
- points (np.ndarray | :obj:`BasePoints`): 3D Points.
tuple[:obj:`BasePoints`, np.ndarray] | :obj:`BasePoints`:
- points (:obj:`BasePoints`): 3D Points.
- choices (np.ndarray, optional): The generated random samples.
"""
if not replace:
......@@ -1031,7 +1036,7 @@ class PointSample(BaseTransform):
point_range = range(len(points))
if sample_range is not None and not replace:
# Only sampling the near points when len(points) >= num_samples
dist = np.linalg.norm(points.tensor, axis=1)
dist = np.linalg.norm(points.coord.numpy(), axis=1)
far_inds = np.where(dist >= sample_range)[0]
near_inds = np.where(dist < sample_range)[0]
# in case there are too many far points
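The one behavioural fix in this commit is the dist computation above: `points.tensor` holds every loaded dimension (xyz plus intensity, timestamps, and so on), so taking its norm inflated the distance used for the `sample_range` test, whereas `points.coord` restricts the norm to xyz. A numpy sketch of the difference (the feature layout is an assumption for illustration):

# Hedged numpy illustration of why the norm must be taken over xyz only.
import numpy as np

# One point at (3, 4, 0) with a large extra feature (e.g. raw intensity).
point_with_features = np.array([[3.0, 4.0, 0.0, 200.0]])

dist_all_dims = np.linalg.norm(point_with_features, axis=1)         # old behaviour
dist_xyz_only = np.linalg.norm(point_with_features[:, :3], axis=1)  # fixed behaviour

print(dist_all_dims)   # ~200.06 -> wrongly treated as a far point
print(dist_xyz_only)   # 5.0     -> correct range test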
......@@ -1055,6 +1060,7 @@ class PointSample(BaseTransform):
Args:
input_dict (dict): Result dict from loading pipeline.
Returns:
dict: Results after sampling, 'points', 'pts_instance_mask'
and 'pts_semantic_mask' keys are updated in the result dict.
......@@ -1214,8 +1220,11 @@ class IndoorPatchPointSample(BaseTransform):
return points
def _patch_points_sampling(self, points: BasePoints,
sem_mask: np.ndarray) -> BasePoints:
def _patch_points_sampling(
self,
points: BasePoints,
sem_mask: np.ndarray
) -> Tuple[BasePoints, np.ndarray]:
"""Patch points sampling.
First sample a valid patch.
......@@ -1226,7 +1235,7 @@ class IndoorPatchPointSample(BaseTransform):
sem_mask (np.ndarray): semantic segmentation mask for input points.
Returns:
tuple[:obj:`BasePoints`, np.ndarray] | :obj:`BasePoints`:
tuple[:obj:`BasePoints`, np.ndarray]:
- points (:obj:`BasePoints`): 3D Points.
- choices (np.ndarray): The generated random samples.
......@@ -1433,7 +1442,7 @@ class BackgroundPointsFilter(BaseTransform):
@TRANSFORMS.register_module()
class VoxelBasedPointSampler(object):
class VoxelBasedPointSampler(BaseTransform):
"""Voxel based point sampler.
Apply voxel sampling to multiple sweep points.
......@@ -1445,7 +1454,10 @@ class VoxelBasedPointSampler(object):
for input points.
"""
def __init__(self, cur_sweep_cfg, prev_sweep_cfg=None, time_dim=3):
def __init__(self,
cur_sweep_cfg: dict,
prev_sweep_cfg: Optional[dict] = None,
time_dim: int = 3) -> None:
self.cur_voxel_generator = VoxelGenerator(**cur_sweep_cfg)
self.cur_voxel_num = self.cur_voxel_generator._max_voxels
self.time_dim = time_dim
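For reference, a VoxelBasedPointSampler pipeline entry built from these arguments might look like the sketch below; the voxel sizes, ranges and caps are placeholders, not values from this commit:

# Hedged example of a VoxelBasedPointSampler entry with placeholder values.
voxel_sampler = dict(
    type='VoxelBasedPointSampler',
    cur_sweep_cfg=dict(
        voxel_size=[0.1, 0.1, 0.1],
        point_cloud_range=[-50, -50, -5, 50, 50, 3],
        max_num_points=1,
        max_voxels=65536),
    prev_sweep_cfg=dict(
        voxel_size=[0.1, 0.1, 0.1],
        point_cloud_range=[-50, -50, -5, 50, 50, 3],
        max_num_points=1,
        max_voxels=65536),
    time_dim=3)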
......@@ -1458,7 +1470,10 @@ class VoxelBasedPointSampler(object):
self.prev_voxel_generator = None
self.prev_voxel_num = 0
def _sample_points(self, points, sampler, point_dim):
def _sample_points(self,
points: np.ndarray,
sampler: VoxelGenerator,
point_dim: int) -> np.ndarray:
"""Sample points for each points subset.
Args:
......@@ -1484,7 +1499,7 @@ class VoxelBasedPointSampler(object):
return sample_points
def __call__(self, results):
def transform(self, results: dict) -> dict:
"""Call function to sample points from multiple sweeps.
Args:
......@@ -1766,6 +1781,7 @@ class AffineResize(BaseTransform):
return ref_point3
def __repr__(self):
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f'(img_scale={self.img_scale}, '
repr_str += f'down_ratio={self.down_ratio}) '
......@@ -1786,7 +1802,7 @@ class RandomShiftScale(BaseTransform):
aug_prob (float): The shifting and scaling probability.
"""
def __init__(self, shift_scale: Tuple[float], aug_prob: float):
def __init__(self, shift_scale: Tuple[float], aug_prob: float) -> None:
self.shift_scale = shift_scale
self.aug_prob = aug_prob
......@@ -1825,6 +1841,7 @@ class RandomShiftScale(BaseTransform):
return results
def __repr__(self):
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f'(shift_scale={self.shift_scale}, '
repr_str += f'aug_prob={self.aug_prob}) '
......