"docs/vscode:/vscode.git/clone" did not exist on "a70dd2998b3403da3d1ce9f3d5a27bd42db2d5c7"
Commit 53435c62 authored by Yezhen Cong's avatar Yezhen Cong Committed by Tai-Wang
Browse files

[Refactor] Refactor code structure and docstrings (#803)

* refactor points_in_boxes

* Merge same functions of three boxes

* More docstring fixes and unify x/y/z size

* Add "optional" and fix "Default"

* Add "optional" and fix "Default"

* Add "optional" and fix "Default"

* Add "optional" and fix "Default"

* Add "optional" and fix "Default"

* Remove None in function param type

* Fix unittest

* Add comments for NMS functions

* Merge methods of Points

* Add unittest

* Add optional and default value

* Fix box conversion and add unittest

* Fix comments

* Add unit test

* Indent

* Fix CI

* Remove useless \\

* Remove useless \\

* Remove useless \\

* Remove useless \\

* Remove useless \\

* Add unit test for box bev

* More unit tests and refine docstrings in box_np_ops

* Fix comment

* Add deprecation warning
parent 4f36084f
...@@ -129,7 +129,7 @@ class LyftDataset(Custom3DDataset): ...@@ -129,7 +129,7 @@ class LyftDataset(Custom3DDataset):
index (int): Index of the sample data to get. index (int): Index of the sample data to get.
Returns: Returns:
dict: Data information that will be passed to the data \ dict: Data information that will be passed to the data
preprocessing pipelines. It includes the following keys: preprocessing pipelines. It includes the following keys:
- sample_idx (str): sample index - sample_idx (str): sample index
...@@ -137,7 +137,7 @@ class LyftDataset(Custom3DDataset): ...@@ -137,7 +137,7 @@ class LyftDataset(Custom3DDataset):
- sweeps (list[dict]): infos of sweeps - sweeps (list[dict]): infos of sweeps
- timestamp (float): sample timestamp - timestamp (float): sample timestamp
- img_filename (str, optional): image filename - img_filename (str, optional): image filename
- lidar2img (list[np.ndarray], optional): transformations \ - lidar2img (list[np.ndarray], optional): transformations
from lidar to different cameras from lidar to different cameras
- ann_info (dict): annotation info - ann_info (dict): annotation info
""" """
...@@ -190,7 +190,7 @@ class LyftDataset(Custom3DDataset): ...@@ -190,7 +190,7 @@ class LyftDataset(Custom3DDataset):
Returns: Returns:
dict: Annotation information consists of the following keys: dict: Annotation information consists of the following keys:
- gt_bboxes_3d (:obj:`LiDARInstance3DBoxes`): \ - gt_bboxes_3d (:obj:`LiDARInstance3DBoxes`):
3D ground truth bboxes. 3D ground truth bboxes.
- gt_labels_3d (np.ndarray): Labels of ground truths. - gt_labels_3d (np.ndarray): Labels of ground truths.
- gt_names (list[str]): Class names of ground truths. - gt_names (list[str]): Class names of ground truths.
...@@ -275,10 +275,11 @@ class LyftDataset(Custom3DDataset): ...@@ -275,10 +275,11 @@ class LyftDataset(Custom3DDataset):
Args: Args:
result_path (str): Path of the result file. result_path (str): Path of the result file.
logger (logging.Logger | str | None): Logger used for printing logger (logging.Logger | str, optional): Logger used for printing
related information during evaluation. Default: None. related information during evaluation. Default: None.
metric (str): Metric name used for evaluation. Default: 'bbox'. metric (str, optional): Metric name used for evaluation.
result_name (str): Result name in the metric prefix. Default: 'bbox'.
result_name (str, optional): Result name in the metric prefix.
Default: 'pts_bbox'. Default: 'pts_bbox'.
Returns: Returns:
...@@ -312,18 +313,18 @@ class LyftDataset(Custom3DDataset): ...@@ -312,18 +313,18 @@ class LyftDataset(Custom3DDataset):
Args: Args:
results (list[dict]): Testing results of the dataset. results (list[dict]): Testing results of the dataset.
jsonfile_prefix (str | None): The prefix of json files. It includes jsonfile_prefix (str): The prefix of json files. It includes
the file path and the prefix of filename, e.g., "a/b/prefix". the file path and the prefix of filename, e.g., "a/b/prefix".
If not specified, a temp file will be created. Default: None. If not specified, a temp file will be created. Default: None.
csv_savepath (str | None): The path for saving csv files. csv_savepath (str): The path for saving csv files.
It includes the file path and the csv filename, It includes the file path and the csv filename,
e.g., "a/b/filename.csv". If not specified, e.g., "a/b/filename.csv". If not specified,
the result will not be converted to csv file. the result will not be converted to csv file.
Returns: Returns:
tuple: Returns (result_files, tmp_dir), where `result_files` is a \ tuple: Returns (result_files, tmp_dir), where `result_files` is a
dict containing the json filepaths, `tmp_dir` is the temporal \ dict containing the json filepaths, `tmp_dir` is the temporal
directory created for saving json files when \ directory created for saving json files when
`jsonfile_prefix` is not specified. `jsonfile_prefix` is not specified.
""" """
assert isinstance(results, list), 'results must be a list' assert isinstance(results, list), 'results must be a list'
...@@ -372,19 +373,22 @@ class LyftDataset(Custom3DDataset): ...@@ -372,19 +373,22 @@ class LyftDataset(Custom3DDataset):
Args: Args:
results (list[dict]): Testing results of the dataset. results (list[dict]): Testing results of the dataset.
metric (str | list[str]): Metrics to be evaluated. metric (str | list[str], optional): Metrics to be evaluated.
logger (logging.Logger | str | None): Logger used for printing Default: 'bbox'.
logger (logging.Logger | str, optional): Logger used for printing
related information during evaluation. Default: None. related information during evaluation. Default: None.
jsonfile_prefix (str | None): The prefix of json files. It includes jsonfile_prefix (str, optional): The prefix of json files including
the file path and the prefix of filename, e.g., "a/b/prefix". the file path and the prefix of filename, e.g., "a/b/prefix".
If not specified, a temp file will be created. Default: None. If not specified, a temp file will be created. Default: None.
csv_savepath (str | None): The path for saving csv files. csv_savepath (str, optional): The path for saving csv files.
It includes the file path and the csv filename, It includes the file path and the csv filename,
e.g., "a/b/filename.csv". If not specified, e.g., "a/b/filename.csv". If not specified,
the result will not be converted to csv file. the result will not be converted to csv file.
show (bool): Whether to visualize. result_names (list[str], optional): Result names in the
metric prefix. Default: ['pts_bbox'].
show (bool, optional): Whether to visualize.
Default: False. Default: False.
out_dir (str): Path to save the visualization results. out_dir (str, optional): Path to save the visualization results.
Default: None. Default: None.
pipeline (list[dict], optional): raw data loading for showing. pipeline (list[dict], optional): raw data loading for showing.
Default: None. Default: None.
......
...@@ -48,8 +48,9 @@ class NuScenesDataset(Custom3DDataset): ...@@ -48,8 +48,9 @@ class NuScenesDataset(Custom3DDataset):
Defaults to False. Defaults to False.
eval_version (bool, optional): Configuration version of evaluation. eval_version (bool, optional): Configuration version of evaluation.
Defaults to 'detection_cvpr_2019'. Defaults to 'detection_cvpr_2019'.
use_valid_flag (bool): Whether to use `use_valid_flag` key in the info use_valid_flag (bool, optional): Whether to use `use_valid_flag` key
file as mask to filter gt_boxes and gt_names. Defaults to False. in the info file as mask to filter gt_boxes and gt_names.
Defaults to False.
""" """
NameMapping = { NameMapping = {
'movable_object.barrier': 'barrier', 'movable_object.barrier': 'barrier',
...@@ -196,7 +197,7 @@ class NuScenesDataset(Custom3DDataset): ...@@ -196,7 +197,7 @@ class NuScenesDataset(Custom3DDataset):
index (int): Index of the sample data to get. index (int): Index of the sample data to get.
Returns: Returns:
dict: Data information that will be passed to the data \ dict: Data information that will be passed to the data
preprocessing pipelines. It includes the following keys: preprocessing pipelines. It includes the following keys:
- sample_idx (str): Sample index. - sample_idx (str): Sample index.
...@@ -204,7 +205,7 @@ class NuScenesDataset(Custom3DDataset): ...@@ -204,7 +205,7 @@ class NuScenesDataset(Custom3DDataset):
- sweeps (list[dict]): Infos of sweeps. - sweeps (list[dict]): Infos of sweeps.
- timestamp (float): Sample timestamp. - timestamp (float): Sample timestamp.
- img_filename (str, optional): Image filename. - img_filename (str, optional): Image filename.
- lidar2img (list[np.ndarray], optional): Transformations \ - lidar2img (list[np.ndarray], optional): Transformations
from lidar to different cameras. from lidar to different cameras.
- ann_info (dict): Annotation info. - ann_info (dict): Annotation info.
""" """
...@@ -256,7 +257,7 @@ class NuScenesDataset(Custom3DDataset): ...@@ -256,7 +257,7 @@ class NuScenesDataset(Custom3DDataset):
Returns: Returns:
dict: Annotation information consists of the following keys: dict: Annotation information consists of the following keys:
- gt_bboxes_3d (:obj:`LiDARInstance3DBoxes`): \ - gt_bboxes_3d (:obj:`LiDARInstance3DBoxes`):
3D ground truth bboxes 3D ground truth bboxes
- gt_labels_3d (np.ndarray): Labels of ground truths. - gt_labels_3d (np.ndarray): Labels of ground truths.
- gt_names (list[str]): Class names of ground truths. - gt_names (list[str]): Class names of ground truths.
...@@ -374,10 +375,11 @@ class NuScenesDataset(Custom3DDataset): ...@@ -374,10 +375,11 @@ class NuScenesDataset(Custom3DDataset):
Args: Args:
result_path (str): Path of the result file. result_path (str): Path of the result file.
logger (logging.Logger | str | None): Logger used for printing logger (logging.Logger | str, optional): Logger used for printing
related information during evaluation. Default: None. related information during evaluation. Default: None.
metric (str): Metric name used for evaluation. Default: 'bbox'. metric (str, optional): Metric name used for evaluation.
result_name (str): Result name in the metric prefix. Default: 'bbox'.
result_name (str, optional): Result name in the metric prefix.
Default: 'pts_bbox'. Default: 'pts_bbox'.
Returns: Returns:
...@@ -427,14 +429,14 @@ class NuScenesDataset(Custom3DDataset): ...@@ -427,14 +429,14 @@ class NuScenesDataset(Custom3DDataset):
Args: Args:
results (list[dict]): Testing results of the dataset. results (list[dict]): Testing results of the dataset.
jsonfile_prefix (str | None): The prefix of json files. It includes jsonfile_prefix (str): The prefix of json files. It includes
the file path and the prefix of filename, e.g., "a/b/prefix". the file path and the prefix of filename, e.g., "a/b/prefix".
If not specified, a temp file will be created. Default: None. If not specified, a temp file will be created. Default: None.
Returns: Returns:
tuple: Returns (result_files, tmp_dir), where `result_files` is a \ tuple: Returns (result_files, tmp_dir), where `result_files` is a
dict containing the json filepaths, `tmp_dir` is the temporal \ dict containing the json filepaths, `tmp_dir` is the temporal
directory created for saving json files when \ directory created for saving json files when
`jsonfile_prefix` is not specified. `jsonfile_prefix` is not specified.
""" """
assert isinstance(results, list), 'results must be a list' assert isinstance(results, list), 'results must be a list'
...@@ -480,15 +482,16 @@ class NuScenesDataset(Custom3DDataset): ...@@ -480,15 +482,16 @@ class NuScenesDataset(Custom3DDataset):
Args: Args:
results (list[dict]): Testing results of the dataset. results (list[dict]): Testing results of the dataset.
metric (str | list[str]): Metrics to be evaluated. metric (str | list[str], optional): Metrics to be evaluated.
logger (logging.Logger | str | None): Logger used for printing Default: 'bbox'.
logger (logging.Logger | str, optional): Logger used for printing
related information during evaluation. Default: None. related information during evaluation. Default: None.
jsonfile_prefix (str | None): The prefix of json files. It includes jsonfile_prefix (str, optional): The prefix of json files including
the file path and the prefix of filename, e.g., "a/b/prefix". the file path and the prefix of filename, e.g., "a/b/prefix".
If not specified, a temp file will be created. Default: None. If not specified, a temp file will be created. Default: None.
show (bool): Whether to visualize. show (bool, optional): Whether to visualize.
Default: False. Default: False.
out_dir (str): Path to save the visualization results. out_dir (str, optional): Path to save the visualization results.
Default: None. Default: None.
pipeline (list[dict], optional): raw data loading for showing. pipeline (list[dict], optional): raw data loading for showing.
Default: None. Default: None.
...@@ -624,7 +627,7 @@ def lidar_nusc_box_to_global(info, ...@@ -624,7 +627,7 @@ def lidar_nusc_box_to_global(info,
boxes (list[:obj:`NuScenesBox`]): List of predicted NuScenesBoxes. boxes (list[:obj:`NuScenesBox`]): List of predicted NuScenesBoxes.
classes (list[str]): Mapped classes in the evaluation. classes (list[str]): Mapped classes in the evaluation.
eval_configs (object): Evaluation configuration object. eval_configs (object): Evaluation configuration object.
eval_version (str): Evaluation version. eval_version (str, optional): Evaluation version.
Default: 'detection_cvpr_2019' Default: 'detection_cvpr_2019'
Returns: Returns:
......
...@@ -44,8 +44,9 @@ class NuScenesMonoDataset(CocoDataset): ...@@ -44,8 +44,9 @@ class NuScenesMonoDataset(CocoDataset):
- 'Camera': Box in camera coordinates. - 'Camera': Box in camera coordinates.
eval_version (str, optional): Configuration version of evaluation. eval_version (str, optional): Configuration version of evaluation.
Defaults to 'detection_cvpr_2019'. Defaults to 'detection_cvpr_2019'.
use_valid_flag (bool): Whether to use `use_valid_flag` key in the info use_valid_flag (bool, optional): Whether to use `use_valid_flag` key
file as mask to filter gt_boxes and gt_names. Defaults to False. in the info file as mask to filter gt_boxes and gt_names.
Defaults to False.
version (str, optional): Dataset version. Defaults to 'v1.0-trainval'. version (str, optional): Dataset version. Defaults to 'v1.0-trainval'.
""" """
CLASSES = ('car', 'truck', 'trailer', 'bus', 'construction_vehicle', CLASSES = ('car', 'truck', 'trailer', 'bus', 'construction_vehicle',
...@@ -140,8 +141,8 @@ class NuScenesMonoDataset(CocoDataset): ...@@ -140,8 +141,8 @@ class NuScenesMonoDataset(CocoDataset):
ann_info (list[dict]): Annotation info of an image. ann_info (list[dict]): Annotation info of an image.
Returns: Returns:
dict: A dict containing the following keys: bboxes, labels, \ dict: A dict containing the following keys: bboxes, labels,
gt_bboxes_3d, gt_labels_3d, attr_labels, centers2d, \ gt_bboxes_3d, gt_labels_3d, attr_labels, centers2d,
depths, bboxes_ignore, masks, seg_map depths, bboxes_ignore, masks, seg_map
""" """
gt_bboxes = [] gt_bboxes = []
...@@ -394,10 +395,11 @@ class NuScenesMonoDataset(CocoDataset): ...@@ -394,10 +395,11 @@ class NuScenesMonoDataset(CocoDataset):
Args: Args:
result_path (str): Path of the result file. result_path (str): Path of the result file.
logger (logging.Logger | str | None): Logger used for printing logger (logging.Logger | str, optional): Logger used for printing
related information during evaluation. Default: None. related information during evaluation. Default: None.
metric (str): Metric name used for evaluation. Default: 'bbox'. metric (str, optional): Metric name used for evaluation.
result_name (str): Result name in the metric prefix. Default: 'bbox'.
result_name (str, optional): Result name in the metric prefix.
Default: 'img_bbox'. Default: 'img_bbox'.
Returns: Returns:
...@@ -448,13 +450,13 @@ class NuScenesMonoDataset(CocoDataset): ...@@ -448,13 +450,13 @@ class NuScenesMonoDataset(CocoDataset):
Args: Args:
results (list[tuple | numpy.ndarray]): Testing results of the results (list[tuple | numpy.ndarray]): Testing results of the
dataset. dataset.
jsonfile_prefix (str | None): The prefix of json files. It includes jsonfile_prefix (str): The prefix of json files. It includes
the file path and the prefix of filename, e.g., "a/b/prefix". the file path and the prefix of filename, e.g., "a/b/prefix".
If not specified, a temp file will be created. Default: None. If not specified, a temp file will be created. Default: None.
Returns: Returns:
tuple: (result_files, tmp_dir), result_files is a dict containing \ tuple: (result_files, tmp_dir), result_files is a dict containing
the json filepaths, tmp_dir is the temporal directory created \ the json filepaths, tmp_dir is the temporal directory created
for saving json files when jsonfile_prefix is not specified. for saving json files when jsonfile_prefix is not specified.
""" """
assert isinstance(results, list), 'results must be a list' assert isinstance(results, list), 'results must be a list'
...@@ -504,15 +506,18 @@ class NuScenesMonoDataset(CocoDataset): ...@@ -504,15 +506,18 @@ class NuScenesMonoDataset(CocoDataset):
Args: Args:
results (list[dict]): Testing results of the dataset. results (list[dict]): Testing results of the dataset.
metric (str | list[str]): Metrics to be evaluated. metric (str | list[str], optional): Metrics to be evaluated.
logger (logging.Logger | str | None): Logger used for printing Default: 'bbox'.
logger (logging.Logger | str, optional): Logger used for printing
related information during evaluation. Default: None. related information during evaluation. Default: None.
jsonfile_prefix (str | None): The prefix of json files. It includes jsonfile_prefix (str): The prefix of json files. It includes
the file path and the prefix of filename, e.g., "a/b/prefix". the file path and the prefix of filename, e.g., "a/b/prefix".
If not specified, a temp file will be created. Default: None. If not specified, a temp file will be created. Default: None.
show (bool): Whether to visualize. result_names (list[str], optional): Result names in the
metric prefix. Default: ['img_bbox'].
show (bool, optional): Whether to visualize.
Default: False. Default: False.
out_dir (str): Path to save the visualization results. out_dir (str, optional): Path to save the visualization results.
Default: None. Default: None.
pipeline (list[dict], optional): raw data loading for showing. pipeline (list[dict], optional): raw data loading for showing.
Default: None. Default: None.
...@@ -576,7 +581,7 @@ class NuScenesMonoDataset(CocoDataset): ...@@ -576,7 +581,7 @@ class NuScenesMonoDataset(CocoDataset):
"""Get data loading pipeline in self.show/evaluate function. """Get data loading pipeline in self.show/evaluate function.
Args: Args:
pipeline (list[dict] | None): Input pipeline. If None is given, \ pipeline (list[dict]): Input pipeline. If None is given,
get from self.pipeline. get from self.pipeline.
""" """
if pipeline is None: if pipeline is None:
...@@ -696,7 +701,7 @@ def cam_nusc_box_to_global(info, ...@@ -696,7 +701,7 @@ def cam_nusc_box_to_global(info,
boxes (list[:obj:`NuScenesBox`]): List of predicted NuScenesBoxes. boxes (list[:obj:`NuScenesBox`]): List of predicted NuScenesBoxes.
classes (list[str]): Mapped classes in the evaluation. classes (list[str]): Mapped classes in the evaluation.
eval_configs (object): Evaluation configuration object. eval_configs (object): Evaluation configuration object.
eval_version (str): Evaluation version. eval_version (str, optional): Evaluation version.
Default: 'detection_cvpr_2019' Default: 'detection_cvpr_2019'
Returns: Returns:
...@@ -736,7 +741,7 @@ def global_nusc_box_to_cam(info, ...@@ -736,7 +741,7 @@ def global_nusc_box_to_cam(info,
boxes (list[:obj:`NuScenesBox`]): List of predicted NuScenesBoxes. boxes (list[:obj:`NuScenesBox`]): List of predicted NuScenesBoxes.
classes (list[str]): Mapped classes in the evaluation. classes (list[str]): Mapped classes in the evaluation.
eval_configs (object): Evaluation configuration object. eval_configs (object): Evaluation configuration object.
eval_version (str): Evaluation version. eval_version (str, optional): Evaluation version.
Default: 'detection_cvpr_2019' Default: 'detection_cvpr_2019'
Returns: Returns:
...@@ -769,7 +774,7 @@ def nusc_box_to_cam_box3d(boxes): ...@@ -769,7 +774,7 @@ def nusc_box_to_cam_box3d(boxes):
boxes (list[:obj:`NuScenesBox`]): List of predicted NuScenesBoxes. boxes (list[:obj:`NuScenesBox`]): List of predicted NuScenesBoxes.
Returns: Returns:
tuple (:obj:`CameraInstance3DBoxes` | torch.Tensor | torch.Tensor): \ tuple (:obj:`CameraInstance3DBoxes` | torch.Tensor | torch.Tensor):
Converted 3D bounding boxes, scores and labels. Converted 3D bounding boxes, scores and labels.
""" """
locs = torch.Tensor([b.center for b in boxes]).view(-1, 3) locs = torch.Tensor([b.center for b in boxes]).view(-1, 3)
......
...@@ -34,8 +34,8 @@ def box_collision_test(boxes, qboxes, clockwise=True): ...@@ -34,8 +34,8 @@ def box_collision_test(boxes, qboxes, clockwise=True):
Args: Args:
boxes (np.ndarray): Corners of current boxes. boxes (np.ndarray): Corners of current boxes.
qboxes (np.ndarray): Boxes to be avoid colliding. qboxes (np.ndarray): Boxes to be avoid colliding.
clockwise (bool): Whether the corners are in clockwise order. clockwise (bool, optional): Whether the corners are in
Default: True. clockwise order. Default: True.
""" """
N = boxes.shape[0] N = boxes.shape[0]
K = qboxes.shape[0] K = qboxes.shape[0]
...@@ -317,7 +317,7 @@ def box3d_transform_(boxes, loc_transform, rot_transform, valid_mask): ...@@ -317,7 +317,7 @@ def box3d_transform_(boxes, loc_transform, rot_transform, valid_mask):
boxes (np.ndarray): 3D boxes to be transformed. boxes (np.ndarray): 3D boxes to be transformed.
loc_transform (np.ndarray): Location transform to be applied. loc_transform (np.ndarray): Location transform to be applied.
rot_transform (np.ndarray): Rotation transform to be applied. rot_transform (np.ndarray): Rotation transform to be applied.
valid_mask (np.ndarray | None): Mask to indicate which boxes are valid. valid_mask (np.ndarray): Mask to indicate which boxes are valid.
""" """
num_box = boxes.shape[0] num_box = boxes.shape[0]
for i in range(num_box): for i in range(num_box):
...@@ -338,16 +338,17 @@ def noise_per_object_v3_(gt_boxes, ...@@ -338,16 +338,17 @@ def noise_per_object_v3_(gt_boxes,
Args: Args:
gt_boxes (np.ndarray): Ground truth boxes with shape (N, 7). gt_boxes (np.ndarray): Ground truth boxes with shape (N, 7).
points (np.ndarray | None): Input point cloud with shape (M, 4). points (np.ndarray, optional): Input point cloud with
Default: None. shape (M, 4). Default: None.
valid_mask (np.ndarray | None): Mask to indicate which boxes are valid. valid_mask (np.ndarray, optional): Mask to indicate which
Default: None. boxes are valid. Default: None.
rotation_perturb (float): Rotation perturbation. Default: pi / 4. rotation_perturb (float, optional): Rotation perturbation.
center_noise_std (float): Center noise standard deviation. Default: pi / 4.
center_noise_std (float, optional): Center noise standard deviation.
Default: 1.0. Default: 1.0.
global_random_rot_range (float): Global random rotation range. global_random_rot_range (float, optional): Global random rotation
Default: pi/4. range. Default: pi/4.
num_try (int): Number of try. Default: 100. num_try (int, optional): Number of try. Default: 100.
""" """
num_boxes = gt_boxes.shape[0] num_boxes = gt_boxes.shape[0]
if not isinstance(rotation_perturb, (list, tuple, np.ndarray)): if not isinstance(rotation_perturb, (list, tuple, np.ndarray)):
......
...@@ -15,10 +15,10 @@ class BatchSampler: ...@@ -15,10 +15,10 @@ class BatchSampler:
Args: Args:
sample_list (list[dict]): List of samples. sample_list (list[dict]): List of samples.
name (str | None): The category of samples. Default: None. name (str, optional): The category of samples. Default: None.
epoch (int | None): Sampling epoch. Default: None. epoch (int, optional): Sampling epoch. Default: None.
shuffle (bool): Whether to shuffle indices. Default: False. shuffle (bool, optional): Whether to shuffle indices. Default: False.
drop_reminder (bool): Drop reminder. Default: False. drop_reminder (bool, optional): Drop reminder. Default: False.
""" """
def __init__(self, def __init__(self,
...@@ -87,9 +87,9 @@ class DataBaseSampler(object): ...@@ -87,9 +87,9 @@ class DataBaseSampler(object):
rate (float): Rate of actual sampled over maximum sampled number. rate (float): Rate of actual sampled over maximum sampled number.
prepare (dict): Name of preparation functions and the input value. prepare (dict): Name of preparation functions and the input value.
sample_groups (dict): Sampled classes and numbers. sample_groups (dict): Sampled classes and numbers.
classes (list[str]): List of classes. Default: None. classes (list[str], optional): List of classes. Default: None.
points_loader(dict): Config of points loader. Default: dict( points_loader(dict, optional): Config of points loader. Default:
type='LoadPointsFromFile', load_dim=4, use_dim=[0,1,2,3]) dict(type='LoadPointsFromFile', load_dim=4, use_dim=[0,1,2,3])
""" """
def __init__(self, def __init__(self,
...@@ -198,9 +198,9 @@ class DataBaseSampler(object): ...@@ -198,9 +198,9 @@ class DataBaseSampler(object):
Returns: Returns:
dict: Dict of sampled 'pseudo ground truths'. dict: Dict of sampled 'pseudo ground truths'.
- gt_labels_3d (np.ndarray): ground truths labels \ - gt_labels_3d (np.ndarray): ground truths labels
of sampled objects. of sampled objects.
- gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): \ - gt_bboxes_3d (:obj:`BaseInstance3DBoxes`):
sampled ground truth 3D bounding boxes sampled ground truth 3D bounding boxes
- points (np.ndarray): sampled points - points (np.ndarray): sampled points
- group_ids (np.ndarray): ids of sampled ground truths - group_ids (np.ndarray): ids of sampled ground truths
......
...@@ -24,7 +24,7 @@ class DefaultFormatBundle(object): ...@@ -24,7 +24,7 @@ class DefaultFormatBundle(object):
- gt_bboxes_ignore: (1)to tensor, (2)to DataContainer - gt_bboxes_ignore: (1)to tensor, (2)to DataContainer
- gt_labels: (1)to tensor, (2)to DataContainer - gt_labels: (1)to tensor, (2)to DataContainer
- gt_masks: (1)to tensor, (2)to DataContainer (cpu_only=True) - gt_masks: (1)to tensor, (2)to DataContainer (cpu_only=True)
- gt_semantic_seg: (1)unsqueeze dim-0 (2)to tensor, \ - gt_semantic_seg: (1)unsqueeze dim-0 (2)to tensor,
(3)to DataContainer (stack=True) (3)to DataContainer (stack=True)
""" """
...@@ -92,8 +92,8 @@ class Collect3D(object): ...@@ -92,8 +92,8 @@ class Collect3D(object):
The "img_meta" item is always populated. The contents of the "img_meta" The "img_meta" item is always populated. The contents of the "img_meta"
dictionary depends on "meta_keys". By default this includes: dictionary depends on "meta_keys". By default this includes:
- 'img_shape': shape of the image input to the network as a tuple \ - 'img_shape': shape of the image input to the network as a tuple
(h, w, c). Note that images may be zero padded on the \ (h, w, c). Note that images may be zero padded on the
bottom/right if the batch tensor is larger than this shape. bottom/right if the batch tensor is larger than this shape.
- 'scale_factor': a float indicating the preprocessing scale - 'scale_factor': a float indicating the preprocessing scale
- 'flip': a boolean indicating if image flip transform was used - 'flip': a boolean indicating if image flip transform was used
...@@ -103,9 +103,9 @@ class Collect3D(object): ...@@ -103,9 +103,9 @@ class Collect3D(object):
- 'lidar2img': transform from lidar to image - 'lidar2img': transform from lidar to image
- 'depth2img': transform from depth to image - 'depth2img': transform from depth to image
- 'cam2img': transform from camera to image - 'cam2img': transform from camera to image
- 'pcd_horizontal_flip': a boolean indicating if point cloud is \ - 'pcd_horizontal_flip': a boolean indicating if point cloud is
flipped horizontally flipped horizontally
- 'pcd_vertical_flip': a boolean indicating if point cloud is \ - 'pcd_vertical_flip': a boolean indicating if point cloud is
flipped vertically flipped vertically
- 'box_mode_3d': 3D box mode - 'box_mode_3d': 3D box mode
- 'box_type_3d': 3D box type - 'box_type_3d': 3D box type
...@@ -130,16 +130,15 @@ class Collect3D(object): ...@@ -130,16 +130,15 @@ class Collect3D(object):
'sample_idx', 'pcd_scale_factor', 'pcd_rotation', 'pts_filename') 'sample_idx', 'pcd_scale_factor', 'pcd_rotation', 'pts_filename')
""" """
def __init__(self, def __init__(
keys, self,
meta_keys=('filename', 'ori_shape', 'img_shape', 'lidar2img', keys,
'depth2img', 'cam2img', 'pad_shape', meta_keys=('filename', 'ori_shape', 'img_shape', 'lidar2img',
'scale_factor', 'flip', 'pcd_horizontal_flip', 'depth2img', 'cam2img', 'pad_shape', 'scale_factor', 'flip',
'pcd_vertical_flip', 'box_mode_3d', 'box_type_3d', 'pcd_horizontal_flip', 'pcd_vertical_flip', 'box_mode_3d',
'img_norm_cfg', 'pcd_trans', 'sample_idx', 'box_type_3d', 'img_norm_cfg', 'pcd_trans', 'sample_idx',
'pcd_scale_factor', 'pcd_rotation', 'pcd_scale_factor', 'pcd_rotation', 'pcd_rotation_angle',
'pcd_rotation_angle', 'pts_filename', 'pts_filename', 'transformation_3d_flow')):
'transformation_3d_flow')):
self.keys = keys self.keys = keys
self.meta_keys = meta_keys self.meta_keys = meta_keys
......
...@@ -14,9 +14,10 @@ class LoadMultiViewImageFromFiles(object): ...@@ -14,9 +14,10 @@ class LoadMultiViewImageFromFiles(object):
Expects results['img_filename'] to be a list of filenames. Expects results['img_filename'] to be a list of filenames.
Args: Args:
to_float32 (bool): Whether to convert the img to float32. to_float32 (bool, optional): Whether to convert the img to float32.
Defaults to False. Defaults to False.
color_type (str): Color type of the file. Defaults to 'unchanged'. color_type (str, optional): Color type of the file.
Defaults to 'unchanged'.
""" """
def __init__(self, to_float32=False, color_type='unchanged'): def __init__(self, to_float32=False, color_type='unchanged'):
...@@ -30,7 +31,7 @@ class LoadMultiViewImageFromFiles(object): ...@@ -30,7 +31,7 @@ class LoadMultiViewImageFromFiles(object):
results (dict): Result dict containing multi-view image filenames. results (dict): Result dict containing multi-view image filenames.
Returns: Returns:
dict: The result dict containing the multi-view image data. \ dict: The result dict containing the multi-view image data.
Added keys and values are described below. Added keys and values are described below.
- filename (str): Multi-view image filenames. - filename (str): Multi-view image filenames.
...@@ -77,7 +78,7 @@ class LoadImageFromFileMono3D(LoadImageFromFile): ...@@ -77,7 +78,7 @@ class LoadImageFromFileMono3D(LoadImageFromFile):
detection, additional camera parameters need to be loaded. detection, additional camera parameters need to be loaded.
Args: Args:
kwargs (dict): Arguments are the same as those in \ kwargs (dict): Arguments are the same as those in
:class:`LoadImageFromFile`. :class:`LoadImageFromFile`.
""" """
...@@ -102,17 +103,20 @@ class LoadPointsFromMultiSweeps(object): ...@@ -102,17 +103,20 @@ class LoadPointsFromMultiSweeps(object):
This is usually used for nuScenes dataset to utilize previous sweeps. This is usually used for nuScenes dataset to utilize previous sweeps.
Args: Args:
sweeps_num (int): Number of sweeps. Defaults to 10. sweeps_num (int, optional): Number of sweeps. Defaults to 10.
load_dim (int): Dimension number of the loaded points. Defaults to 5. load_dim (int, optional): Dimension number of the loaded points.
use_dim (list[int]): Which dimension to use. Defaults to [0, 1, 2, 4]. Defaults to 5.
file_client_args (dict): Config dict of file clients, refer to use_dim (list[int], optional): Which dimension to use.
Defaults to [0, 1, 2, 4].
file_client_args (dict, optional): Config dict of file clients,
refer to
https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py
for more details. Defaults to dict(backend='disk'). for more details. Defaults to dict(backend='disk').
pad_empty_sweeps (bool): Whether to repeat keyframe when pad_empty_sweeps (bool, optional): Whether to repeat keyframe when
sweeps is empty. Defaults to False. sweeps is empty. Defaults to False.
remove_close (bool): Whether to remove close points. remove_close (bool, optional): Whether to remove close points.
Defaults to False. Defaults to False.
test_mode (bool): If test_model=True used for testing, it will not test_mode (bool, optional): If `test_mode=True`, it will not
randomly sample sweeps but select the nearest N frames. randomly sample sweeps but select the nearest N frames.
Defaults to False. Defaults to False.
""" """
...@@ -161,7 +165,7 @@ class LoadPointsFromMultiSweeps(object): ...@@ -161,7 +165,7 @@ class LoadPointsFromMultiSweeps(object):
Args: Args:
points (np.ndarray | :obj:`BasePoints`): Sweep points. points (np.ndarray | :obj:`BasePoints`): Sweep points.
radius (float): Radius below which points are removed. radius (float, optional): Radius below which points are removed.
Defaults to 1.0. Defaults to 1.0.
Returns: Returns:
...@@ -182,14 +186,14 @@ class LoadPointsFromMultiSweeps(object): ...@@ -182,14 +186,14 @@ class LoadPointsFromMultiSweeps(object):
"""Call function to load multi-sweep point clouds from files. """Call function to load multi-sweep point clouds from files.
Args: Args:
results (dict): Result dict containing multi-sweep point cloud \ results (dict): Result dict containing multi-sweep point cloud
filenames. filenames.
Returns: Returns:
dict: The result dict containing the multi-sweep points data. \ dict: The result dict containing the multi-sweep points data.
Added key and value are described below. Added key and value are described below.
- points (np.ndarray | :obj:`BasePoints`): Multi-sweep point \ - points (np.ndarray | :obj:`BasePoints`): Multi-sweep point
cloud arrays. cloud arrays.
""" """
points = results['points'] points = results['points']
...@@ -243,8 +247,8 @@ class PointSegClassMapping(object): ...@@ -243,8 +247,8 @@ class PointSegClassMapping(object):
Args: Args:
valid_cat_ids (tuple[int]): A tuple of valid category. valid_cat_ids (tuple[int]): A tuple of valid category.
max_cat_id (int): The max possible cat_id in input segmentation mask. max_cat_id (int, optional): The max possible cat_id in input
Defaults to 40. segmentation mask. Defaults to 40.
""" """
def __init__(self, valid_cat_ids, max_cat_id=40): def __init__(self, valid_cat_ids, max_cat_id=40):
...@@ -268,7 +272,7 @@ class PointSegClassMapping(object): ...@@ -268,7 +272,7 @@ class PointSegClassMapping(object):
results (dict): Result dict containing point semantic masks. results (dict): Result dict containing point semantic masks.
Returns: Returns:
dict: The result dict containing the mapped category ids. \ dict: The result dict containing the mapped category ids.
Updated key and value are described below. Updated key and value are described below.
- pts_semantic_mask (np.ndarray): Mapped semantic masks. - pts_semantic_mask (np.ndarray): Mapped semantic masks.
...@@ -307,7 +311,7 @@ class NormalizePointsColor(object): ...@@ -307,7 +311,7 @@ class NormalizePointsColor(object):
results (dict): Result dict containing point clouds data. results (dict): Result dict containing point clouds data.
Returns: Returns:
dict: The result dict containing the normalized points. \ dict: The result dict containing the normalized points.
Updated key and value are described below. Updated key and value are described below.
- points (:obj:`BasePoints`): Points after color normalization. - points (:obj:`BasePoints`): Points after color normalization.
...@@ -342,14 +346,17 @@ class LoadPointsFromFile(object): ...@@ -342,14 +346,17 @@ class LoadPointsFromFile(object):
- 'LIDAR': Points in LiDAR coordinates. - 'LIDAR': Points in LiDAR coordinates.
- 'DEPTH': Points in depth coordinates, usually for indoor dataset. - 'DEPTH': Points in depth coordinates, usually for indoor dataset.
- 'CAMERA': Points in camera coordinates. - 'CAMERA': Points in camera coordinates.
load_dim (int): The dimension of the loaded points. load_dim (int, optional): The dimension of the loaded points.
Defaults to 6. Defaults to 6.
use_dim (list[int]): Which dimensions of the points to be used. use_dim (list[int], optional): Which dimensions of the points to use.
Defaults to [0, 1, 2]. For KITTI dataset, set use_dim=4 Defaults to [0, 1, 2]. For KITTI dataset, set use_dim=4
or use_dim=[0, 1, 2, 3] to use the intensity dimension. or use_dim=[0, 1, 2, 3] to use the intensity dimension.
shift_height (bool): Whether to use shifted height. Defaults to False. shift_height (bool, optional): Whether to use shifted height.
use_color (bool): Whether to use color features. Defaults to False. Defaults to False.
file_client_args (dict): Config dict of file clients, refer to use_color (bool, optional): Whether to use color features.
Defaults to False.
file_client_args (dict, optional): Config dict of file clients,
refer to
https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py
for more details. Defaults to dict(backend='disk'). for more details. Defaults to dict(backend='disk').
""" """
...@@ -405,7 +412,7 @@ class LoadPointsFromFile(object): ...@@ -405,7 +412,7 @@ class LoadPointsFromFile(object):
results (dict): Result dict containing point clouds data. results (dict): Result dict containing point clouds data.
Returns: Returns:
dict: The result dict containing the point clouds data. \ dict: The result dict containing the point clouds data.
Added key and value are described below. Added key and value are described below.
- points (:obj:`BasePoints`): Point clouds data. - points (:obj:`BasePoints`): Point clouds data.
......
...@@ -16,18 +16,19 @@ class MultiScaleFlipAug3D(object): ...@@ -16,18 +16,19 @@ class MultiScaleFlipAug3D(object):
img_scale (tuple | list[tuple]: Images scales for resizing. img_scale (tuple | list[tuple]: Images scales for resizing.
pts_scale_ratio (float | list[float]): Points scale ratios for pts_scale_ratio (float | list[float]): Points scale ratios for
resizing. resizing.
flip (bool): Whether apply flip augmentation. Defaults to False. flip (bool, optional): Whether apply flip augmentation.
flip_direction (str | list[str]): Flip augmentation directions Defaults to False.
for images, options are "horizontal" and "vertical". flip_direction (str | list[str], optional): Flip augmentation
directions for images, options are "horizontal" and "vertical".
If flip_direction is list, multiple flip augmentations will If flip_direction is list, multiple flip augmentations will
be applied. It has no effect when ``flip == False``. be applied. It has no effect when ``flip == False``.
Defaults to "horizontal". Defaults to "horizontal".
pcd_horizontal_flip (bool): Whether apply horizontal flip augmentation pcd_horizontal_flip (bool, optional): Whether apply horizontal
to point cloud. Defaults to True. Note that it works only when flip augmentation to point cloud. Defaults to True.
'flip' is turned on. Note that it works only when 'flip' is turned on.
pcd_vertical_flip (bool): Whether apply vertical flip augmentation pcd_vertical_flip (bool, optional): Whether apply vertical flip
to point cloud. Defaults to True. Note that it works only when augmentation to point cloud. Defaults to True.
'flip' is turned on. Note that it works only when 'flip' is turned on.
""" """
def __init__(self, def __init__(self,
...@@ -70,7 +71,7 @@ class MultiScaleFlipAug3D(object): ...@@ -70,7 +71,7 @@ class MultiScaleFlipAug3D(object):
results (dict): Result dict contains the data to augment. results (dict): Result dict contains the data to augment.
Returns: Returns:
dict: The result dict contains the data that is augmented with \ dict: The result dict contains the data that is augmented with
different scales and flips. different scales and flips.
""" """
aug_data = [] aug_data = []
......
...@@ -22,7 +22,7 @@ class RandomDropPointsColor(object): ...@@ -22,7 +22,7 @@ class RandomDropPointsColor(object):
util/transform.py#L223>`_ for more details. util/transform.py#L223>`_ for more details.
Args: Args:
drop_ratio (float): The probability of dropping point colors. drop_ratio (float, optional): The probability of dropping point colors.
Defaults to 0.2. Defaults to 0.2.
""" """
...@@ -38,7 +38,7 @@ class RandomDropPointsColor(object): ...@@ -38,7 +38,7 @@ class RandomDropPointsColor(object):
input_dict (dict): Result dict from loading pipeline. input_dict (dict): Result dict from loading pipeline.
Returns: Returns:
dict: Results after color dropping, \ dict: Results after color dropping,
'points' key is updated in the result dict. 'points' key is updated in the result dict.
""" """
points = input_dict['points'] points = input_dict['points']
...@@ -105,10 +105,11 @@ class RandomFlip3D(RandomFlip): ...@@ -105,10 +105,11 @@ class RandomFlip3D(RandomFlip):
Args: Args:
input_dict (dict): Result dict from loading pipeline. input_dict (dict): Result dict from loading pipeline.
direction (str): Flip direction. Default: horizontal. direction (str, optional): Flip direction.
Default: 'horizontal'.
Returns: Returns:
dict: Flipped results, 'points', 'bbox3d_fields' keys are \ dict: Flipped results, 'points', 'bbox3d_fields' keys are
updated in the result dict. updated in the result dict.
""" """
assert direction in ['horizontal', 'vertical'] assert direction in ['horizontal', 'vertical']
...@@ -137,15 +138,15 @@ class RandomFlip3D(RandomFlip): ...@@ -137,15 +138,15 @@ class RandomFlip3D(RandomFlip):
input_dict['cam2img'][0][2] = w - input_dict['cam2img'][0][2] input_dict['cam2img'][0][2] = w - input_dict['cam2img'][0][2]
def __call__(self, input_dict): def __call__(self, input_dict):
"""Call function to flip points, values in the ``bbox3d_fields`` and \ """Call function to flip points, values in the ``bbox3d_fields`` and
also flip 2D image and its annotations. also flip 2D image and its annotations.
Args: Args:
input_dict (dict): Result dict from loading pipeline. input_dict (dict): Result dict from loading pipeline.
Returns: Returns:
dict: Flipped results, 'flip', 'flip_direction', \ dict: Flipped results, 'flip', 'flip_direction',
'pcd_horizontal_flip' and 'pcd_vertical_flip' keys are added \ 'pcd_horizontal_flip' and 'pcd_vertical_flip' keys are added
into result dict. into result dict.
""" """
# filp 2D image and its annotations # filp 2D image and its annotations
...@@ -187,20 +188,20 @@ class RandomFlip3D(RandomFlip): ...@@ -187,20 +188,20 @@ class RandomFlip3D(RandomFlip):
class RandomJitterPoints(object): class RandomJitterPoints(object):
"""Randomly jitter point coordinates. """Randomly jitter point coordinates.
Different from the global translation in ``GlobalRotScaleTrans``, here we \ Different from the global translation in ``GlobalRotScaleTrans``, here we
apply different noises to each point in a scene. apply different noises to each point in a scene.
Args: Args:
jitter_std (list[float]): The standard deviation of jittering noise. jitter_std (list[float]): The standard deviation of jittering noise.
This applies random noise to all points in a 3D scene, which is \ This applies random noise to all points in a 3D scene, which is
sampled from a gaussian distribution whose standard deviation is \ sampled from a gaussian distribution whose standard deviation is
set by ``jitter_std``. Defaults to [0.01, 0.01, 0.01] set by ``jitter_std``. Defaults to [0.01, 0.01, 0.01]
clip_range (list[float] | None): Clip the randomly generated jitter \ clip_range (list[float]): Clip the randomly generated jitter
noise into this range. If None is given, don't perform clipping. noise into this range. If None is given, don't perform clipping.
Defaults to [-0.05, 0.05] Defaults to [-0.05, 0.05]
Note: Note:
This transform should only be used in point cloud segmentation tasks \ This transform should only be used in point cloud segmentation tasks
because we don't transform ground-truth bboxes accordingly. because we don't transform ground-truth bboxes accordingly.
For similar transform in detection task, please refer to `ObjectNoise`. For similar transform in detection task, please refer to `ObjectNoise`.
""" """
...@@ -229,7 +230,7 @@ class RandomJitterPoints(object): ...@@ -229,7 +230,7 @@ class RandomJitterPoints(object):
input_dict (dict): Result dict from loading pipeline. input_dict (dict): Result dict from loading pipeline.
Returns: Returns:
dict: Results after adding noise to each point, \ dict: Results after adding noise to each point,
'points' key is updated in the result dict. 'points' key is updated in the result dict.
""" """
points = input_dict['points'] points = input_dict['points']
...@@ -291,8 +292,8 @@ class ObjectSample(object): ...@@ -291,8 +292,8 @@ class ObjectSample(object):
input_dict (dict): Result dict from loading pipeline. input_dict (dict): Result dict from loading pipeline.
Returns: Returns:
dict: Results after object sampling augmentation, \ dict: Results after object sampling augmentation,
'points', 'gt_bboxes_3d', 'gt_labels_3d' keys are updated \ 'points', 'gt_bboxes_3d', 'gt_labels_3d' keys are updated
in the result dict. in the result dict.
""" """
gt_bboxes_3d = input_dict['gt_bboxes_3d'] gt_bboxes_3d = input_dict['gt_bboxes_3d']
...@@ -388,7 +389,7 @@ class ObjectNoise(object): ...@@ -388,7 +389,7 @@ class ObjectNoise(object):
input_dict (dict): Result dict from loading pipeline. input_dict (dict): Result dict from loading pipeline.
Returns: Returns:
dict: Results after adding noise to each object, \ dict: Results after adding noise to each object,
'points', 'gt_bboxes_3d' keys are updated in the result dict. 'points', 'gt_bboxes_3d' keys are updated in the result dict.
""" """
gt_bboxes_3d = input_dict['gt_bboxes_3d'] gt_bboxes_3d = input_dict['gt_bboxes_3d']
...@@ -428,10 +429,10 @@ class GlobalAlignment(object): ...@@ -428,10 +429,10 @@ class GlobalAlignment(object):
rotation_axis (int): Rotation axis for points and bboxes rotation. rotation_axis (int): Rotation axis for points and bboxes rotation.
Note: Note:
We do not record the applied rotation and translation as in \ We do not record the applied rotation and translation as in
GlobalRotScaleTrans. Because usually, we do not need to reverse \ GlobalRotScaleTrans. Because usually, we do not need to reverse
the alignment step. the alignment step.
For example, ScanNet 3D detection task uses aligned ground-truth \ For example, ScanNet 3D detection task uses aligned ground-truth
bounding boxes for evaluation. bounding boxes for evaluation.
""" """
...@@ -483,7 +484,7 @@ class GlobalAlignment(object): ...@@ -483,7 +484,7 @@ class GlobalAlignment(object):
input_dict (dict): Result dict from loading pipeline. input_dict (dict): Result dict from loading pipeline.
Returns: Returns:
dict: Results after global alignment, 'points' and keys in \ dict: Results after global alignment, 'points' and keys in
input_dict['bbox3d_fields'] are updated in the result dict. input_dict['bbox3d_fields'] are updated in the result dict.
""" """
assert 'axis_align_matrix' in input_dict['ann_info'].keys(), \ assert 'axis_align_matrix' in input_dict['ann_info'].keys(), \
...@@ -512,15 +513,15 @@ class GlobalRotScaleTrans(object): ...@@ -512,15 +513,15 @@ class GlobalRotScaleTrans(object):
"""Apply global rotation, scaling and translation to a 3D scene. """Apply global rotation, scaling and translation to a 3D scene.
Args: Args:
rot_range (list[float]): Range of rotation angle. rot_range (list[float], optional): Range of rotation angle.
Defaults to [-0.78539816, 0.78539816] (close to [-pi/4, pi/4]). Defaults to [-0.78539816, 0.78539816] (close to [-pi/4, pi/4]).
scale_ratio_range (list[float]): Range of scale ratio. scale_ratio_range (list[float], optional): Range of scale ratio.
Defaults to [0.95, 1.05]. Defaults to [0.95, 1.05].
translation_std (list[float]): The standard deviation of translation translation_std (list[float], optional): The standard deviation of
noise. This applies random translation to a scene by a noise, which translation noise applied to a scene, which
is sampled from a gaussian distribution whose standard deviation is sampled from a gaussian distribution whose standard deviation
is set by ``translation_std``. Defaults to [0, 0, 0] is set by ``translation_std``. Defaults to [0, 0, 0]
shift_height (bool): Whether to shift height. shift_height (bool, optional): Whether to shift height.
(the fourth dimension of indoor points) when scaling. (the fourth dimension of indoor points) when scaling.
Defaults to False. Defaults to False.
""" """
...@@ -559,8 +560,8 @@ class GlobalRotScaleTrans(object): ...@@ -559,8 +560,8 @@ class GlobalRotScaleTrans(object):
input_dict (dict): Result dict from loading pipeline. input_dict (dict): Result dict from loading pipeline.
Returns: Returns:
dict: Results after translation, 'points', 'pcd_trans' \ dict: Results after translation, 'points', 'pcd_trans'
and keys in input_dict['bbox3d_fields'] are updated \ and keys in input_dict['bbox3d_fields'] are updated
in the result dict. in the result dict.
""" """
translation_std = np.array(self.translation_std, dtype=np.float32) translation_std = np.array(self.translation_std, dtype=np.float32)
...@@ -578,8 +579,8 @@ class GlobalRotScaleTrans(object): ...@@ -578,8 +579,8 @@ class GlobalRotScaleTrans(object):
input_dict (dict): Result dict from loading pipeline. input_dict (dict): Result dict from loading pipeline.
Returns: Returns:
dict: Results after rotation, 'points', 'pcd_rotation' \ dict: Results after rotation, 'points', 'pcd_rotation'
and keys in input_dict['bbox3d_fields'] are updated \ and keys in input_dict['bbox3d_fields'] are updated
in the result dict. in the result dict.
""" """
rotation = self.rot_range rotation = self.rot_range
...@@ -608,7 +609,7 @@ class GlobalRotScaleTrans(object): ...@@ -608,7 +609,7 @@ class GlobalRotScaleTrans(object):
input_dict (dict): Result dict from loading pipeline. input_dict (dict): Result dict from loading pipeline.
Returns: Returns:
dict: Results after scaling, 'points'and keys in \ dict: Results after scaling, 'points'and keys in
input_dict['bbox3d_fields'] are updated in the result dict. input_dict['bbox3d_fields'] are updated in the result dict.
""" """
scale = input_dict['pcd_scale_factor'] scale = input_dict['pcd_scale_factor']
...@@ -630,7 +631,7 @@ class GlobalRotScaleTrans(object): ...@@ -630,7 +631,7 @@ class GlobalRotScaleTrans(object):
input_dict (dict): Result dict from loading pipeline. input_dict (dict): Result dict from loading pipeline.
Returns: Returns:
dict: Results after scaling, 'pcd_scale_factor' are updated \ dict: Results after scaling, 'pcd_scale_factor' are updated
in the result dict. in the result dict.
""" """
scale_factor = np.random.uniform(self.scale_ratio_range[0], scale_factor = np.random.uniform(self.scale_ratio_range[0],
...@@ -638,7 +639,7 @@ class GlobalRotScaleTrans(object): ...@@ -638,7 +639,7 @@ class GlobalRotScaleTrans(object):
input_dict['pcd_scale_factor'] = scale_factor input_dict['pcd_scale_factor'] = scale_factor
def __call__(self, input_dict): def __call__(self, input_dict):
"""Private function to rotate, scale and translate bounding boxes and \ """Private function to rotate, scale and translate bounding boxes and
points. points.
Args: Args:
...@@ -646,7 +647,7 @@ class GlobalRotScaleTrans(object): ...@@ -646,7 +647,7 @@ class GlobalRotScaleTrans(object):
Returns: Returns:
dict: Results after scaling, 'points', 'pcd_rotation', dict: Results after scaling, 'points', 'pcd_rotation',
'pcd_scale_factor', 'pcd_trans' and keys in \ 'pcd_scale_factor', 'pcd_trans' and keys in
input_dict['bbox3d_fields'] are updated in the result dict. input_dict['bbox3d_fields'] are updated in the result dict.
""" """
if 'transformation_3d_flow' not in input_dict: if 'transformation_3d_flow' not in input_dict:
...@@ -684,7 +685,7 @@ class PointShuffle(object): ...@@ -684,7 +685,7 @@ class PointShuffle(object):
input_dict (dict): Result dict from loading pipeline. input_dict (dict): Result dict from loading pipeline.
Returns: Returns:
dict: Results after filtering, 'points', 'pts_instance_mask' \ dict: Results after filtering, 'points', 'pts_instance_mask'
and 'pts_semantic_mask' keys are updated in the result dict. and 'pts_semantic_mask' keys are updated in the result dict.
""" """
idx = input_dict['points'].shuffle() idx = input_dict['points'].shuffle()
...@@ -723,7 +724,7 @@ class ObjectRangeFilter(object): ...@@ -723,7 +724,7 @@ class ObjectRangeFilter(object):
input_dict (dict): Result dict from loading pipeline. input_dict (dict): Result dict from loading pipeline.
Returns: Returns:
dict: Results after filtering, 'gt_bboxes_3d', 'gt_labels_3d' \ dict: Results after filtering, 'gt_bboxes_3d', 'gt_labels_3d'
keys are updated in the result dict. keys are updated in the result dict.
""" """
# Check points instance type and initialise bev_range # Check points instance type and initialise bev_range
...@@ -775,7 +776,7 @@ class PointsRangeFilter(object): ...@@ -775,7 +776,7 @@ class PointsRangeFilter(object):
input_dict (dict): Result dict from loading pipeline. input_dict (dict): Result dict from loading pipeline.
Returns: Returns:
dict: Results after filtering, 'points', 'pts_instance_mask' \ dict: Results after filtering, 'points', 'pts_instance_mask'
and 'pts_semantic_mask' keys are updated in the result dict. and 'pts_semantic_mask' keys are updated in the result dict.
""" """
points = input_dict['points'] points = input_dict['points']
...@@ -821,7 +822,7 @@ class ObjectNameFilter(object): ...@@ -821,7 +822,7 @@ class ObjectNameFilter(object):
input_dict (dict): Result dict from loading pipeline. input_dict (dict): Result dict from loading pipeline.
Returns: Returns:
dict: Results after filtering, 'gt_bboxes_3d', 'gt_labels_3d' \ dict: Results after filtering, 'gt_bboxes_3d', 'gt_labels_3d'
keys are updated in the result dict. keys are updated in the result dict.
""" """
gt_labels_3d = input_dict['gt_labels_3d'] gt_labels_3d = input_dict['gt_labels_3d']
...@@ -913,7 +914,7 @@ class PointSample(object): ...@@ -913,7 +914,7 @@ class PointSample(object):
Args: Args:
input_dict (dict): Result dict from loading pipeline. input_dict (dict): Result dict from loading pipeline.
Returns: Returns:
dict: Results after sampling, 'points', 'pts_instance_mask' \ dict: Results after sampling, 'points', 'pts_instance_mask'
and 'pts_semantic_mask' keys are updated in the result dict. and 'pts_semantic_mask' keys are updated in the result dict.
""" """
points = results['points'] points = results['points']
...@@ -994,10 +995,10 @@ class IndoorPatchPointSample(object): ...@@ -994,10 +995,10 @@ class IndoorPatchPointSample(object):
additional features. Defaults to False. additional features. Defaults to False.
num_try (int, optional): Number of times to try if the patch selected num_try (int, optional): Number of times to try if the patch selected
is invalid. Defaults to 10. is invalid. Defaults to 10.
enlarge_size (float | None, optional): Enlarge the sampled patch to enlarge_size (float, optional): Enlarge the sampled patch to
[-block_size / 2 - enlarge_size, block_size / 2 + enlarge_size] as [-block_size / 2 - enlarge_size, block_size / 2 + enlarge_size] as
an augmentation. If None, set it as 0. Defaults to 0.2. an augmentation. If None, set it as 0. Defaults to 0.2.
min_unique_num (int | None, optional): Minimum number of unique points min_unique_num (int, optional): Minimum number of unique points
the sampled patch should contain. If None, use PointNet++'s method the sampled patch should contain. If None, use PointNet++'s method
to judge uniqueness. Defaults to None. to judge uniqueness. Defaults to None.
eps (float, optional): A value added to patch boundary to guarantee eps (float, optional): A value added to patch boundary to guarantee
...@@ -1038,7 +1039,7 @@ class IndoorPatchPointSample(object): ...@@ -1038,7 +1039,7 @@ class IndoorPatchPointSample(object):
attribute_dims, point_type): attribute_dims, point_type):
"""Generating model input. """Generating model input.
Generate input by subtracting patch center and adding additional \ Generate input by subtracting patch center and adding additional
features. Currently support colors and normalized xyz as features. features. Currently support colors and normalized xyz as features.
Args: Args:
...@@ -1182,7 +1183,7 @@ class IndoorPatchPointSample(object): ...@@ -1182,7 +1183,7 @@ class IndoorPatchPointSample(object):
input_dict (dict): Result dict from loading pipeline. input_dict (dict): Result dict from loading pipeline.
Returns: Returns:
dict: Results after sampling, 'points', 'pts_instance_mask' \ dict: Results after sampling, 'points', 'pts_instance_mask'
and 'pts_semantic_mask' keys are updated in the result dict. and 'pts_semantic_mask' keys are updated in the result dict.
""" """
points = results['points'] points = results['points']
...@@ -1242,7 +1243,7 @@ class BackgroundPointsFilter(object): ...@@ -1242,7 +1243,7 @@ class BackgroundPointsFilter(object):
input_dict (dict): Result dict from loading pipeline. input_dict (dict): Result dict from loading pipeline.
Returns: Returns:
dict: Results after filtering, 'points', 'pts_instance_mask' \ dict: Results after filtering, 'points', 'pts_instance_mask'
and 'pts_semantic_mask' keys are updated in the result dict. and 'pts_semantic_mask' keys are updated in the result dict.
""" """
points = input_dict['points'] points = input_dict['points']
...@@ -1340,7 +1341,7 @@ class VoxelBasedPointSampler(object): ...@@ -1340,7 +1341,7 @@ class VoxelBasedPointSampler(object):
input_dict (dict): Result dict from loading pipeline. input_dict (dict): Result dict from loading pipeline.
Returns: Returns:
dict: Results after sampling, 'points', 'pts_instance_mask' \ dict: Results after sampling, 'points', 'pts_instance_mask'
and 'pts_semantic_mask' keys are updated in the result dict. and 'pts_semantic_mask' keys are updated in the result dict.
""" """
points = results['points'] points = results['points']
......
...@@ -78,13 +78,13 @@ class ScanNetDataset(Custom3DDataset): ...@@ -78,13 +78,13 @@ class ScanNetDataset(Custom3DDataset):
index (int): Index of the sample data to get. index (int): Index of the sample data to get.
Returns: Returns:
dict: Data information that will be passed to the data \ dict: Data information that will be passed to the data
preprocessing pipelines. It includes the following keys: preprocessing pipelines. It includes the following keys:
- sample_idx (str): Sample index. - sample_idx (str): Sample index.
- pts_filename (str): Filename of point clouds. - pts_filename (str): Filename of point clouds.
- file_name (str): Filename of point clouds. - file_name (str): Filename of point clouds.
- img_prefix (str | None, optional): Prefix of image files. - img_prefix (str, optional): Prefix of image files.
- img_info (dict, optional): Image info. - img_info (dict, optional): Image info.
- ann_info (dict): Annotation info. - ann_info (dict): Annotation info.
""" """
...@@ -129,12 +129,12 @@ class ScanNetDataset(Custom3DDataset): ...@@ -129,12 +129,12 @@ class ScanNetDataset(Custom3DDataset):
Returns: Returns:
dict: annotation information consists of the following keys: dict: annotation information consists of the following keys:
- gt_bboxes_3d (:obj:`DepthInstance3DBoxes`): \ - gt_bboxes_3d (:obj:`DepthInstance3DBoxes`):
3D ground truth bboxes 3D ground truth bboxes
- gt_labels_3d (np.ndarray): Labels of ground truths. - gt_labels_3d (np.ndarray): Labels of ground truths.
- pts_instance_mask_path (str): Path of instance masks. - pts_instance_mask_path (str): Path of instance masks.
- pts_semantic_mask_path (str): Path of semantic masks. - pts_semantic_mask_path (str): Path of semantic masks.
- axis_align_matrix (np.ndarray): Transformation matrix for \ - axis_align_matrix (np.ndarray): Transformation matrix for
global scene alignment. global scene alignment.
""" """
# Use index to get the annos, thus the evalhook could also use this api # Use index to get the annos, thus the evalhook could also use this api
...@@ -172,7 +172,7 @@ class ScanNetDataset(Custom3DDataset): ...@@ -172,7 +172,7 @@ class ScanNetDataset(Custom3DDataset):
def prepare_test_data(self, index): def prepare_test_data(self, index):
"""Prepare data for testing. """Prepare data for testing.
We should take axis_align_matrix from self.data_infos since we need \ We should take axis_align_matrix from self.data_infos since we need
to align point clouds. to align point clouds.
Args: Args:
...@@ -272,7 +272,7 @@ class ScanNetSegDataset(Custom3DSegDataset): ...@@ -272,7 +272,7 @@ class ScanNetSegDataset(Custom3DSegDataset):
as input. Defaults to None. as input. Defaults to None.
test_mode (bool, optional): Whether the dataset is in test mode. test_mode (bool, optional): Whether the dataset is in test mode.
Defaults to False. Defaults to False.
ignore_index (int, optional): The label index to be ignored, e.g. \ ignore_index (int, optional): The label index to be ignored, e.g.
unannotated points. If None is given, set to len(self.CLASSES). unannotated points. If None is given, set to len(self.CLASSES).
Defaults to None. Defaults to None.
scene_idxs (np.ndarray | str, optional): Precomputed index to load scene_idxs (np.ndarray | str, optional): Precomputed index to load
...@@ -424,7 +424,7 @@ class ScanNetSegDataset(Custom3DSegDataset): ...@@ -424,7 +424,7 @@ class ScanNetSegDataset(Custom3DSegDataset):
Args: Args:
outputs (list[dict]): Testing results of the dataset. outputs (list[dict]): Testing results of the dataset.
txtfile_prefix (str | None): The prefix of saved files. It includes txtfile_prefix (str): The prefix of saved files. It includes
the file path and the prefix of filename, e.g., "a/b/prefix". the file path and the prefix of filename, e.g., "a/b/prefix".
If not specified, a temp file will be created. Default: None. If not specified, a temp file will be created. Default: None.
......
...@@ -74,13 +74,13 @@ class SUNRGBDDataset(Custom3DDataset): ...@@ -74,13 +74,13 @@ class SUNRGBDDataset(Custom3DDataset):
index (int): Index of the sample data to get. index (int): Index of the sample data to get.
Returns: Returns:
dict: Data information that will be passed to the data \ dict: Data information that will be passed to the data
preprocessing pipelines. It includes the following keys: preprocessing pipelines. It includes the following keys:
- sample_idx (str): Sample index. - sample_idx (str): Sample index.
- pts_filename (str, optional): Filename of point clouds. - pts_filename (str, optional): Filename of point clouds.
- file_name (str, optional): Filename of point clouds. - file_name (str, optional): Filename of point clouds.
- img_prefix (str | None, optional): Prefix of image files. - img_prefix (str, optional): Prefix of image files.
- img_info (dict, optional): Image info. - img_info (dict, optional): Image info.
- calib (dict, optional): Camera calibration info. - calib (dict, optional): Camera calibration info.
- ann_info (dict): Annotation info. - ann_info (dict): Annotation info.
...@@ -125,7 +125,7 @@ class SUNRGBDDataset(Custom3DDataset): ...@@ -125,7 +125,7 @@ class SUNRGBDDataset(Custom3DDataset):
Returns: Returns:
dict: annotation information consists of the following keys: dict: annotation information consists of the following keys:
- gt_bboxes_3d (:obj:`DepthInstance3DBoxes`): \ - gt_bboxes_3d (:obj:`DepthInstance3DBoxes`):
3D ground truth bboxes 3D ground truth bboxes
- gt_labels_3d (np.ndarray): Labels of ground truths. - gt_labels_3d (np.ndarray): Labels of ground truths.
- pts_instance_mask_path (str): Path of instance masks. - pts_instance_mask_path (str): Path of instance masks.
...@@ -239,12 +239,15 @@ class SUNRGBDDataset(Custom3DDataset): ...@@ -239,12 +239,15 @@ class SUNRGBDDataset(Custom3DDataset):
Args: Args:
results (list[dict]): List of results. results (list[dict]): List of results.
metric (str | list[str]): Metrics to be evaluated. metric (str | list[str], optional): Metrics to be evaluated.
iou_thr (list[float]): AP IoU thresholds. Default: None.
iou_thr_2d (list[float]): AP IoU thresholds for 2d evaluation. iou_thr (list[float], optional): AP IoU thresholds for 3D
show (bool): Whether to visualize. evaluation. Default: (0.25, 0.5).
iou_thr_2d (list[float], optional): AP IoU thresholds for 2D
evaluation. Default: (0.5, ).
show (bool, optional): Whether to visualize.
Default: False. Default: False.
out_dir (str): Path to save the visualization results. out_dir (str, optional): Path to save the visualization results.
Default: None. Default: None.
pipeline (list[dict], optional): raw data loading for showing. pipeline (list[dict], optional): raw data loading for showing.
Default: None. Default: None.
......
...@@ -25,7 +25,7 @@ def is_loading_function(transform): ...@@ -25,7 +25,7 @@ def is_loading_function(transform):
transform (dict | :obj:`Pipeline`): A transform config or a function. transform (dict | :obj:`Pipeline`): A transform config or a function.
Returns: Returns:
bool | None: Whether it is a loading function. None means can't judge. bool: Whether it is a loading function. None means can't judge.
When transform is `MultiScaleFlipAug3D`, we return None. When transform is `MultiScaleFlipAug3D`, we return None.
""" """
# TODO: use more elegant way to distinguish loading modules # TODO: use more elegant way to distinguish loading modules
...@@ -92,7 +92,7 @@ def get_loading_pipeline(pipeline): ...@@ -92,7 +92,7 @@ def get_loading_pipeline(pipeline):
... dict(type='Collect3D', ... dict(type='Collect3D',
... keys=['points', 'img', 'gt_bboxes_3d', 'gt_labels_3d']) ... keys=['points', 'img', 'gt_bboxes_3d', 'gt_labels_3d'])
... ] ... ]
>>> assert expected_pipelines ==\ >>> assert expected_pipelines == \
... get_loading_pipeline(pipelines) ... get_loading_pipeline(pipelines)
""" """
loading_pipeline = [] loading_pipeline = []
...@@ -126,7 +126,7 @@ def extract_result_dict(results, key): ...@@ -126,7 +126,7 @@ def extract_result_dict(results, key):
key (str): Key of the desired data. key (str): Key of the desired data.
Returns: Returns:
np.ndarray | torch.Tensor | None: Data term. np.ndarray | torch.Tensor: Data term.
""" """
if key not in results.keys(): if key not in results.keys():
return None return None
......
...@@ -46,8 +46,9 @@ class WaymoDataset(KittiDataset): ...@@ -46,8 +46,9 @@ class WaymoDataset(KittiDataset):
Defaults to True. Defaults to True.
test_mode (bool, optional): Whether the dataset is in test mode. test_mode (bool, optional): Whether the dataset is in test mode.
Defaults to False. Defaults to False.
pcd_limit_range (list): The range of point cloud used to filter pcd_limit_range (list(float), optional): The range of point cloud used
invalid predicted boxes. Default: [-85, -85, -5, 85, 85, 5]. to filter invalid predicted boxes.
Default: [-85, -85, -5, 85, 85, 5].
""" """
CLASSES = ('Car', 'Cyclist', 'Pedestrian') CLASSES = ('Car', 'Cyclist', 'Pedestrian')
...@@ -100,7 +101,7 @@ class WaymoDataset(KittiDataset): ...@@ -100,7 +101,7 @@ class WaymoDataset(KittiDataset):
- sample_idx (str): sample index - sample_idx (str): sample index
- pts_filename (str): filename of point clouds - pts_filename (str): filename of point clouds
- img_prefix (str | None): prefix of image files - img_prefix (str): prefix of image files
- img_info (dict): image info - img_info (dict): image info
- lidar2img (list[np.ndarray], optional): transformations from - lidar2img (list[np.ndarray], optional): transformations from
lidar to different cameras lidar to different cameras
...@@ -140,15 +141,15 @@ class WaymoDataset(KittiDataset): ...@@ -140,15 +141,15 @@ class WaymoDataset(KittiDataset):
Args: Args:
outputs (list[dict]): Testing results of the dataset. outputs (list[dict]): Testing results of the dataset.
pklfile_prefix (str | None): The prefix of pkl files. It includes pklfile_prefix (str): The prefix of pkl files. It includes
the file path and the prefix of filename, e.g., "a/b/prefix". the file path and the prefix of filename, e.g., "a/b/prefix".
If not specified, a temp file will be created. Default: None. If not specified, a temp file will be created. Default: None.
submission_prefix (str | None): The prefix of submitted files. It submission_prefix (str): The prefix of submitted files. It
includes the file path and the prefix of filename, e.g., includes the file path and the prefix of filename, e.g.,
"a/b/prefix". If not specified, a temp file will be created. "a/b/prefix". If not specified, a temp file will be created.
Default: None. Default: None.
data_format (str | None): Output data format. Default: 'waymo'. data_format (str, optional): Output data format.
Another supported choice is 'kitti'. Default: 'waymo'. Another supported choice is 'kitti'.
Returns: Returns:
tuple: (result_files, tmp_dir), result_files is a dict containing tuple: (result_files, tmp_dir), result_files is a dict containing
...@@ -226,18 +227,18 @@ class WaymoDataset(KittiDataset): ...@@ -226,18 +227,18 @@ class WaymoDataset(KittiDataset):
Args: Args:
results (list[dict]): Testing results of the dataset. results (list[dict]): Testing results of the dataset.
metric (str | list[str]): Metrics to be evaluated. metric (str | list[str], optional): Metrics to be evaluated.
Default: 'waymo'. Another supported metric is 'kitti'. Default: 'waymo'. Another supported metric is 'kitti'.
logger (logging.Logger | str | None): Logger used for printing logger (logging.Logger | str, optional): Logger used for printing
related information during evaluation. Default: None. related information during evaluation. Default: None.
pklfile_prefix (str | None): The prefix of pkl files. It includes pklfile_prefix (str, optional): The prefix of pkl files including
the file path and the prefix of filename, e.g., "a/b/prefix". the file path and the prefix of filename, e.g., "a/b/prefix".
If not specified, a temp file will be created. Default: None. If not specified, a temp file will be created. Default: None.
submission_prefix (str | None): The prefix of submission datas. submission_prefix (str, optional): The prefix of submission datas.
If not specified, the submission data will not be generated. If not specified, the submission data will not be generated.
show (bool): Whether to visualize. show (bool, optional): Whether to visualize.
Default: False. Default: False.
out_dir (str): Path to save the visualization results. out_dir (str, optional): Path to save the visualization results.
Default: None. Default: None.
pipeline (list[dict], optional): raw data loading for showing. pipeline (list[dict], optional): raw data loading for showing.
Default: None. Default: None.
...@@ -364,8 +365,8 @@ class WaymoDataset(KittiDataset): ...@@ -364,8 +365,8 @@ class WaymoDataset(KittiDataset):
net_outputs (List[np.ndarray]): list of array storing the net_outputs (List[np.ndarray]): list of array storing the
bbox and score bbox and score
class_nanes (List[String]): A list of class names class_nanes (List[String]): A list of class names
pklfile_prefix (str | None): The prefix of pkl file. pklfile_prefix (str): The prefix of pkl file.
submission_prefix (str | None): The prefix of submission file. submission_prefix (str): The prefix of submission file.
Returns: Returns:
List[dict]: A list of dict have the kitti 3d format List[dict]: A list of dict have the kitti 3d format
......
...@@ -134,7 +134,7 @@ class PointNet2SAMSG(BasePointNet): ...@@ -134,7 +134,7 @@ class PointNet2SAMSG(BasePointNet):
- sa_xyz (torch.Tensor): The coordinates of sa features. - sa_xyz (torch.Tensor): The coordinates of sa features.
- sa_features (torch.Tensor): The features from the - sa_features (torch.Tensor): The features from the
last Set Aggregation Layers. last Set Aggregation Layers.
- sa_indices (torch.Tensor): Indices of the \ - sa_indices (torch.Tensor): Indices of the
input points. input points.
""" """
xyz, features = self._split_point_feats(points) xyz, features = self._split_point_feats(points)
......
...@@ -97,11 +97,11 @@ class PointNet2SASSG(BasePointNet): ...@@ -97,11 +97,11 @@ class PointNet2SASSG(BasePointNet):
Returns: Returns:
dict[str, list[torch.Tensor]]: Outputs after SA and FP modules. dict[str, list[torch.Tensor]]: Outputs after SA and FP modules.
- fp_xyz (list[torch.Tensor]): The coordinates of \ - fp_xyz (list[torch.Tensor]): The coordinates of
each fp features. each fp features.
- fp_features (list[torch.Tensor]): The features \ - fp_features (list[torch.Tensor]): The features
from each Feature Propagate Layers. from each Feature Propagate Layers.
- fp_indices (list[torch.Tensor]): Indices of the \ - fp_indices (list[torch.Tensor]): Indices of the
input points. input points.
""" """
xyz, features = self._split_point_feats(points) xyz, features = self._split_point_feats(points)
......
...@@ -13,17 +13,18 @@ class Base3DDecodeHead(BaseModule, metaclass=ABCMeta): ...@@ -13,17 +13,18 @@ class Base3DDecodeHead(BaseModule, metaclass=ABCMeta):
Args: Args:
channels (int): Channels after modules, before conv_seg. channels (int): Channels after modules, before conv_seg.
num_classes (int): Number of classes. num_classes (int): Number of classes.
dropout_ratio (float): Ratio of dropout layer. Default: 0.5. dropout_ratio (float, optional): Ratio of dropout layer. Default: 0.5.
conv_cfg (dict|None): Config of conv layers. conv_cfg (dict, optional): Config of conv layers.
Default: dict(type='Conv1d'). Default: dict(type='Conv1d').
norm_cfg (dict|None): Config of norm layers. norm_cfg (dict, optional): Config of norm layers.
Default: dict(type='BN1d'). Default: dict(type='BN1d').
act_cfg (dict): Config of activation layers. act_cfg (dict, optional): Config of activation layers.
Default: dict(type='ReLU'). Default: dict(type='ReLU').
loss_decode (dict): Config of decode loss. loss_decode (dict, optional): Config of decode loss.
Default: dict(type='CrossEntropyLoss'). Default: dict(type='CrossEntropyLoss').
ignore_index (int | None): The label index to be ignored. When using ignore_index (int, optional): The label index to be ignored.
masked BCE loss, ignore_index should be set to None. Default: 255. When using masked BCE loss, ignore_index should be set to None.
Default: 255.
""" """
def __init__(self, def __init__(self,
...@@ -110,9 +111,9 @@ class Base3DDecodeHead(BaseModule, metaclass=ABCMeta): ...@@ -110,9 +111,9 @@ class Base3DDecodeHead(BaseModule, metaclass=ABCMeta):
"""Compute semantic segmentation loss. """Compute semantic segmentation loss.
Args: Args:
seg_logit (torch.Tensor): Predicted per-point segmentation logits \ seg_logit (torch.Tensor): Predicted per-point segmentation logits
of shape [B, num_classes, N]. of shape [B, num_classes, N].
seg_label (torch.Tensor): Ground-truth segmentation label of \ seg_label (torch.Tensor): Ground-truth segmentation label of
shape [B, N]. shape [B, N].
""" """
loss = dict() loss = dict()
......
...@@ -14,7 +14,7 @@ class PAConvHead(PointNet2Head): ...@@ -14,7 +14,7 @@ class PAConvHead(PointNet2Head):
Args: Args:
fp_channels (tuple[tuple[int]]): Tuple of mlp channels in FP modules. fp_channels (tuple[tuple[int]]): Tuple of mlp channels in FP modules.
fp_norm_cfg (dict|None): Config of norm layers used in FP modules. fp_norm_cfg (dict): Config of norm layers used in FP modules.
Default: dict(type='BN2d'). Default: dict(type='BN2d').
""" """
......
...@@ -16,7 +16,7 @@ class PointNet2Head(Base3DDecodeHead): ...@@ -16,7 +16,7 @@ class PointNet2Head(Base3DDecodeHead):
Args: Args:
fp_channels (tuple[tuple[int]]): Tuple of mlp channels in FP modules. fp_channels (tuple[tuple[int]]): Tuple of mlp channels in FP modules.
fp_norm_cfg (dict|None): Config of norm layers used in FP modules. fp_norm_cfg (dict): Config of norm layers used in FP modules.
Default: dict(type='BN2d'). Default: dict(type='BN2d').
""" """
......
...@@ -145,7 +145,7 @@ class Anchor3DHead(BaseModule, AnchorTrainMixin): ...@@ -145,7 +145,7 @@ class Anchor3DHead(BaseModule, AnchorTrainMixin):
x (torch.Tensor): Input features. x (torch.Tensor): Input features.
Returns: Returns:
tuple[torch.Tensor]: Contain score of each class, bbox \ tuple[torch.Tensor]: Contain score of each class, bbox
regression and direction classification predictions. regression and direction classification predictions.
""" """
cls_score = self.conv_cls(x) cls_score = self.conv_cls(x)
...@@ -163,7 +163,7 @@ class Anchor3DHead(BaseModule, AnchorTrainMixin): ...@@ -163,7 +163,7 @@ class Anchor3DHead(BaseModule, AnchorTrainMixin):
features produced by FPN. features produced by FPN.
Returns: Returns:
tuple[list[torch.Tensor]]: Multi-level class score, bbox \ tuple[list[torch.Tensor]]: Multi-level class score, bbox
and direction predictions. and direction predictions.
""" """
return multi_apply(self.forward_single, feats) return multi_apply(self.forward_single, feats)
...@@ -177,7 +177,7 @@ class Anchor3DHead(BaseModule, AnchorTrainMixin): ...@@ -177,7 +177,7 @@ class Anchor3DHead(BaseModule, AnchorTrainMixin):
device (str): device of current module. device (str): device of current module.
Returns: Returns:
list[list[torch.Tensor]]: Anchors of each image, valid flags \ list[list[torch.Tensor]]: Anchors of each image, valid flags
of each image. of each image.
""" """
num_imgs = len(input_metas) num_imgs = len(input_metas)
...@@ -207,7 +207,7 @@ class Anchor3DHead(BaseModule, AnchorTrainMixin): ...@@ -207,7 +207,7 @@ class Anchor3DHead(BaseModule, AnchorTrainMixin):
num_total_samples (int): The number of valid samples. num_total_samples (int): The number of valid samples.
Returns: Returns:
tuple[torch.Tensor]: Losses of class, bbox \ tuple[torch.Tensor]: Losses of class, bbox
and direction, respectively. and direction, respectively.
""" """
# classification loss # classification loss
...@@ -285,7 +285,7 @@ class Anchor3DHead(BaseModule, AnchorTrainMixin): ...@@ -285,7 +285,7 @@ class Anchor3DHead(BaseModule, AnchorTrainMixin):
the 7th dimension is rotation dimension. the 7th dimension is rotation dimension.
Returns: Returns:
tuple[torch.Tensor]: ``boxes1`` and ``boxes2`` whose 7th \ tuple[torch.Tensor]: ``boxes1`` and ``boxes2`` whose 7th
dimensions are changed. dimensions are changed.
""" """
rad_pred_encoding = torch.sin(boxes1[..., 6:7]) * torch.cos( rad_pred_encoding = torch.sin(boxes1[..., 6:7]) * torch.cos(
...@@ -318,16 +318,16 @@ class Anchor3DHead(BaseModule, AnchorTrainMixin): ...@@ -318,16 +318,16 @@ class Anchor3DHead(BaseModule, AnchorTrainMixin):
of each sample. of each sample.
gt_labels (list[torch.Tensor]): Gt labels of each sample. gt_labels (list[torch.Tensor]): Gt labels of each sample.
input_metas (list[dict]): Contain pcd and img's meta info. input_metas (list[dict]): Contain pcd and img's meta info.
gt_bboxes_ignore (None | list[torch.Tensor]): Specify gt_bboxes_ignore (list[torch.Tensor]): Specify
which bounding. which bounding boxes to ignore.
Returns: Returns:
dict[str, list[torch.Tensor]]: Classification, bbox, and \ dict[str, list[torch.Tensor]]: Classification, bbox, and
direction losses of each level. direction losses of each level.
- loss_cls (list[torch.Tensor]): Classification losses. - loss_cls (list[torch.Tensor]): Classification losses.
- loss_bbox (list[torch.Tensor]): Box regression losses. - loss_bbox (list[torch.Tensor]): Box regression losses.
- loss_dir (list[torch.Tensor]): Direction classification \ - loss_dir (list[torch.Tensor]): Direction classification
losses. losses.
""" """
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
...@@ -385,7 +385,7 @@ class Anchor3DHead(BaseModule, AnchorTrainMixin): ...@@ -385,7 +385,7 @@ class Anchor3DHead(BaseModule, AnchorTrainMixin):
dir_cls_preds (list[torch.Tensor]): Multi-level direction dir_cls_preds (list[torch.Tensor]): Multi-level direction
class predictions. class predictions.
input_metas (list[dict]): Contain pcd and img's meta info. input_metas (list[dict]): Contain pcd and img's meta info.
cfg (None | :obj:`ConfigDict`): Training or testing config. cfg (:obj:`ConfigDict`): Training or testing config.
rescale (list[torch.Tensor]): Whether th rescale bbox. rescale (list[torch.Tensor]): Whether th rescale bbox.
Returns: Returns:
...@@ -439,7 +439,7 @@ class Anchor3DHead(BaseModule, AnchorTrainMixin): ...@@ -439,7 +439,7 @@ class Anchor3DHead(BaseModule, AnchorTrainMixin):
mlvl_anchors (List[torch.Tensor]): Multi-level anchors mlvl_anchors (List[torch.Tensor]): Multi-level anchors
in single batch. in single batch.
input_meta (list[dict]): Contain pcd and img's meta info. input_meta (list[dict]): Contain pcd and img's meta info.
cfg (None | :obj:`ConfigDict`): Training or testing config. cfg (:obj:`ConfigDict`): Training or testing config.
rescale (list[torch.Tensor]): whether th rescale bbox. rescale (list[torch.Tensor]): whether th rescale bbox.
Returns: Returns:
......
...@@ -18,35 +18,45 @@ class AnchorFreeMono3DHead(BaseMono3DDenseHead): ...@@ -18,35 +18,45 @@ class AnchorFreeMono3DHead(BaseMono3DDenseHead):
num_classes (int): Number of categories excluding the background num_classes (int): Number of categories excluding the background
category. category.
in_channels (int): Number of channels in the input feature map. in_channels (int): Number of channels in the input feature map.
feat_channels (int): Number of hidden channels. Used in child classes. feat_channels (int, optional): Number of hidden channels.
stacked_convs (int): Number of stacking convs of the head. Used in child classes. Defaults to 256.
strides (tuple): Downsample factor of each feature map. stacked_convs (int, optional): Number of stacking convs of the head.
dcn_on_last_conv (bool): If true, use dcn in the last layer of strides (tuple, optional): Downsample factor of each feature map.
towers. Default: False. dcn_on_last_conv (bool, optional): If true, use dcn in the last
conv_bias (bool | str): If specified as `auto`, it will be decided by layer of towers. Default: False.
the norm_cfg. Bias of conv will be set as True if `norm_cfg` is conv_bias (bool | str, optional): If specified as `auto`, it will be
None, otherwise False. Default: "auto". decided by the norm_cfg. Bias of conv will be set as True
background_label (int | None): Label ID of background, set as 0 for if `norm_cfg` is None, otherwise False. Default: 'auto'.
RPN and num_classes for other heads. It will automatically set as background_label (int, optional): Label ID of background,
num_classes if None is given. set as 0 for RPN and num_classes for other heads.
use_direction_classifier (bool): Whether to add a direction classifier. It will automatically set as `num_classes` if None is given.
diff_rad_by_sin (bool): Whether to change the difference into sin use_direction_classifier (bool, optional):
difference for box regression loss. Whether to add a direction classifier.
loss_cls (dict): Config of classification loss. diff_rad_by_sin (bool, optional): Whether to change the difference
loss_bbox (dict): Config of localization loss. into sin difference for box regression loss. Defaults to True.
loss_dir (dict): Config of direction classifier loss. dir_offset (float, optional): Parameter used in direction
loss_attr (dict): Config of attribute classifier loss, which is only classification. Defaults to 0.
active when pred_attrs=True. dir_limit_offset (float, optional): Parameter used in direction
bbox_code_size (int): Dimensions of predicted bounding boxes. classification. Defaults to 0.
pred_attrs (bool): Whether to predict attributes. Default to False. loss_cls (dict, optional): Config of classification loss.
num_attrs (int): The number of attributes to be predicted. Default: 9. loss_bbox (dict, optional): Config of localization loss.
pred_velo (bool): Whether to predict velocity. Default to False. loss_dir (dict, optional): Config of direction classifier loss.
pred_bbox2d (bool): Whether to predict 2D boxes. Default to False. loss_attr (dict, optional): Config of attribute classifier loss,
group_reg_dims (tuple[int]): The dimension of each regression target which is only active when `pred_attrs=True`.
group. Default: (2, 1, 3, 1, 2). bbox_code_size (int, optional): Dimensions of predicted bounding boxes.
cls_branch (tuple[int]): Channels for classification branch. pred_attrs (bool, optional): Whether to predict attributes.
Defaults to False.
num_attrs (int, optional): The number of attributes to be predicted.
Default: 9.
pred_velo (bool, optional): Whether to predict velocity.
Defaults to False.
pred_bbox2d (bool, optional): Whether to predict 2D boxes.
Defaults to False.
group_reg_dims (tuple[int], optional): The dimension of each regression
target group. Default: (2, 1, 3, 1, 2).
cls_branch (tuple[int], optional): Channels for classification branch.
Default: (128, 64). Default: (128, 64).
reg_branch (tuple[tuple]): Channels for regression branch. reg_branch (tuple[tuple], optional): Channels for regression branch.
Default: ( Default: (
(128, 64), # offset (128, 64), # offset
(128, 64), # depth (128, 64), # depth
...@@ -54,14 +64,16 @@ class AnchorFreeMono3DHead(BaseMono3DDenseHead): ...@@ -54,14 +64,16 @@ class AnchorFreeMono3DHead(BaseMono3DDenseHead):
(64, ), # rot (64, ), # rot
() # velo () # velo
), ),
dir_branch (tuple[int]): Channels for direction classification branch. dir_branch (tuple[int], optional): Channels for direction
classification branch. Default: (64, ).
attr_branch (tuple[int], optional): Channels for classification branch.
Default: (64, ). Default: (64, ).
attr_branch (tuple[int]): Channels for classification branch. conv_cfg (dict, optional): Config dict for convolution layer.
Default: (64, ). Default: None.
conv_cfg (dict): Config dict for convolution layer. Default: None. norm_cfg (dict, optional): Config dict for normalization layer.
norm_cfg (dict): Config dict for normalization layer. Default: None. Default: None.
train_cfg (dict): Training config of anchor head. train_cfg (dict, optional): Training config of anchor head.
test_cfg (dict): Testing config of anchor head. test_cfg (dict, optional): Testing config of anchor head.
""" # noqa: W605 """ # noqa: W605
_version = 1 _version = 1
...@@ -126,6 +138,7 @@ class AnchorFreeMono3DHead(BaseMono3DDenseHead): ...@@ -126,6 +138,7 @@ class AnchorFreeMono3DHead(BaseMono3DDenseHead):
self.use_direction_classifier = use_direction_classifier self.use_direction_classifier = use_direction_classifier
self.diff_rad_by_sin = diff_rad_by_sin self.diff_rad_by_sin = diff_rad_by_sin
self.dir_offset = dir_offset self.dir_offset = dir_offset
self.dir_limit_offset = dir_limit_offset
self.loss_cls = build_loss(loss_cls) self.loss_cls = build_loss(loss_cls)
self.loss_bbox = build_loss(loss_bbox) self.loss_bbox = build_loss(loss_bbox)
self.loss_dir = build_loss(loss_dir) self.loss_dir = build_loss(loss_dir)
...@@ -290,7 +303,7 @@ class AnchorFreeMono3DHead(BaseMono3DDenseHead): ...@@ -290,7 +303,7 @@ class AnchorFreeMono3DHead(BaseMono3DDenseHead):
a 4D-tensor. a 4D-tensor.
Returns: Returns:
tuple: Usually contain classification scores, bbox predictions, \ tuple: Usually contain classification scores, bbox predictions,
and direction class predictions. and direction class predictions.
cls_scores (list[Tensor]): Box scores for each scale level, cls_scores (list[Tensor]): Box scores for each scale level,
each is a 4D-tensor, the channel number is each is a 4D-tensor, the channel number is
...@@ -402,7 +415,7 @@ class AnchorFreeMono3DHead(BaseMono3DDenseHead): ...@@ -402,7 +415,7 @@ class AnchorFreeMono3DHead(BaseMono3DDenseHead):
corresponding to each box corresponding to each box
img_metas (list[dict]): Meta information of each image, e.g., img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc. image size, scaling factor, etc.
gt_bboxes_ignore (None | list[Tensor]): specify which bounding gt_bboxes_ignore (list[Tensor]): specify which bounding
boxes can be ignored when computing the loss. boxes can be ignored when computing the loss.
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment