Commit adb17824 authored by xiangxu-0103's avatar xiangxu-0103 Committed by ZwwWayne
Browse files

[Fix]: fix semantic segmentation related bugs (#1909)

delete whitespace

update docs

remove unnecessary optional docs

update docs

add mmengine assertion

add docstring

fix mminstall

update mmengine version

fix

[Fix]: fix semantic segmentation related bugs (#1909)

fix semantic seg

fix lint

remove unused imports

fix

update pointnet2-s3dis config

update data_list according to scene_idxs

remove useless function

fix bug lack `eval_ann_info` during evaluation

fix bug

update doc

fix lint

update docs

Update det3d_dataset.py

update docstrings

update docs

fix lint

update docs

fix

fix

fix lint
parent b37dc416
......@@ -166,24 +166,25 @@ class _S3DISSegDataset(Seg3DDataset):
wrapper to concat all the provided data in different areas.
Args:
data_root (str): Path of dataset root.
ann_file (str): Path of annotation file.
pipeline (list[dict], optional): Pipeline used for data processing.
Defaults to None.
classes (tuple[str], optional): Classes used in the dataset.
Defaults to None.
palette (list[list[int]], optional): The palette of segmentation map.
Defaults to None.
modality (dict, optional): Modality to specify the sensor data used
as input. Defaults to None.
test_mode (bool, optional): Whether the dataset is in test mode.
Defaults to False.
data_root (str, optional): Path of dataset root. Defaults to None.
ann_file (str): Path of annotation file. Defaults to ''.
metainfo (dict, optional): Meta information for dataset, such as class
information. Defaults to None.
data_prefix (dict): Prefix for training data. Defaults to
dict(pts='points', instance_mask='', semantic_mask='').
pipeline (list[dict]): Pipeline used for data processing.
Defaults to [].
modality (dict): Modality to specify the sensor data used as input.
Defaults to dict(use_lidar=True, use_camera=False).
ignore_index (int, optional): The label index to be ignored, e.g.
unannotated points. If None is given, set to len(self.CLASSES).
unannotated points. If None is given, set to len(self.CLASSES) to
be consistent with PointSegClassMapping function in pipeline.
Defaults to None.
scene_idxs (np.ndarray | str, optional): Precomputed index to load
data. For scenes with many points, we may sample it several times.
Defaults to None.
test_mode (bool): Whether the dataset is in test mode.
Defaults to False.
"""
METAINFO = {
'CLASSES':
......@@ -207,9 +208,9 @@ class _S3DISSegDataset(Seg3DDataset):
pts='points', img='', instance_mask='', semantic_mask=''),
pipeline: List[Union[dict, Callable]] = [],
modality: dict = dict(use_lidar=True, use_camera=False),
ignore_index=None,
scene_idxs=None,
test_mode=False,
ignore_index: Optional[int] = None,
scene_idxs: Optional[Union[np.ndarray, str]] = None,
test_mode: bool = False,
**kwargs) -> None:
super().__init__(
data_root=data_root,
......@@ -250,37 +251,40 @@ class S3DISSegDataset(_S3DISSegDataset):
data downloading.
Args:
data_root (str): Path of dataset root.
data_root (str, optional): Path of dataset root. Defaults to None.
ann_files (list[str]): Path of several annotation files.
pipeline (list[dict], optional): Pipeline used for data processing.
Defaults to None.
classes (tuple[str], optional): Classes used in the dataset.
Defaults to None.
palette (list[list[int]], optional): The palette of segmentation map.
Defaults to None.
modality (dict, optional): Modality to specify the sensor data used
as input. Defaults to None.
test_mode (bool, optional): Whether the dataset is in test mode.
Defaults to False.
Defaults to ''.
metainfo (dict, optional): Meta information for dataset, such as class
information. Defaults to None.
data_prefix (dict): Prefix for training data. Defaults to
dict(pts='points', instance_mask='', semantic_mask='').
pipeline (list[dict]): Pipeline used for data processing.
Defaults to [].
modality (dict): Modality to specify the sensor data used as input.
Defaults to dict(use_lidar=True, use_camera=False).
ignore_index (int, optional): The label index to be ignored, e.g.
unannotated points. If None is given, set to len(self.CLASSES).
unannotated points. If None is given, set to len(self.CLASSES) to
be consistent with PointSegClassMapping function in pipeline.
Defaults to None.
scene_idxs (list[np.ndarray] | list[str], optional): Precomputed index
to load data. For scenes with many points, we may sample it several
times. Defaults to None.
to load data. For scenes with many points, we may sample it
several times. Defaults to None.
test_mode (bool): Whether the dataset is in test mode.
Defaults to False.
"""
def __init__(self,
data_root: Optional[str] = None,
ann_files: str = '',
ann_files: List[str] = '',
metainfo: Optional[dict] = None,
data_prefix: dict = dict(
pts='points', img='', instance_mask='', semantic_mask=''),
pipeline: List[Union[dict, Callable]] = [],
modality: dict = dict(use_lidar=True, use_camera=False),
ignore_index=None,
scene_idxs=None,
test_mode=False,
ignore_index: Optional[int] = None,
scene_idxs: Optional[Union[List[np.ndarray],
List[str]]] = None,
test_mode: bool = False,
**kwargs) -> None:
# make sure that ann_files and scene_idxs have same length
......@@ -318,13 +322,12 @@ class S3DISSegDataset(_S3DISSegDataset):
# data_list and scene_idxs need to be concat
self.concat_data_list([dst.data_list for dst in datasets])
self.concat_scene_idxs([dst.scene_idxs for dst in datasets])
# set group flag for the sampler
if not self.test_mode:
self._set_group_flag()
def concat_data_list(self, data_lists):
def concat_data_list(self, data_lists: List[List[dict]]) -> List[dict]:
"""Concat data_list from several datasets to form self.data_list.
Args:
......@@ -334,21 +337,6 @@ class S3DISSegDataset(_S3DISSegDataset):
data for data_list in data_lists for data in data_list
]
def concat_scene_idxs(self, scene_idxs):
"""Concat scene_idxs from several datasets to form self.scene_idxs.
Needs to manually add offset to scene_idxs[1, 2, ...].
Args:
scene_idxs (list[np.ndarray])
"""
self.scene_idxs = np.array([], dtype=np.int32)
offset = 0
for one_scene_idxs in scene_idxs:
self.scene_idxs = np.concatenate(
[self.scene_idxs, one_scene_idxs + offset]).astype(np.int32)
offset = np.unique(self.scene_idxs).max() + 1
@staticmethod
def _duplicate_to_list(x, num):
"""Repeat x `num` times to form a list."""
......
......@@ -26,13 +26,13 @@ class ScanNetDataset(Det3DDataset):
metainfo (dict, optional): Meta information for dataset, such as class
information. Defaults to None.
data_prefix (dict): Prefix for data. Defaults to
`dict(pts='points',
pts_isntance_mask='instance_mask',
pts_semantic_mask='semantic_mask')`.
dict(pts='points',
pts_instance_mask='instance_mask',
pts_semantic_mask='semantic_mask').
pipeline (list[dict]): Pipeline used for data processing.
Defaults to None.
modality (dict): Modality to specify the sensor data used
as input. Defaults to None.
Defaults to [].
modality (dict): Modality to specify the sensor data used as input.
Defaults to dict(use_camera=False, use_lidar=True).
box_type_3d (str): Type of 3D box of this dataset.
Based on the `box_type_3d`, the dataset will encapsulate the box
to its original format then converted them to `box_type_3d`.
......@@ -41,8 +41,10 @@ class ScanNetDataset(Det3DDataset):
- 'LiDAR': Box in LiDAR coordinates.
- 'Depth': Box in depth coordinates, usually for indoor dataset.
- 'Camera': Box in camera coordinates.
filter_empty_gt (bool): Whether to filter empty GT.
Defaults to True.
filter_empty_gt (bool): Whether to filter the data with empty GT.
If it's set to be True, the example with empty annotations after
data pipeline will be dropped and a random example will be chosen
in `__getitem__`. Defaults to True.
test_mode (bool): Whether the dataset is in test mode.
Defaults to False.
"""
......@@ -71,7 +73,7 @@ class ScanNetDataset(Det3DDataset):
box_type_3d: str = 'Depth',
filter_empty_gt: bool = True,
test_mode: bool = False,
**kwargs):
**kwargs) -> None:
# construct seg_label_mapping for semantic mask
seg_max_cat_id = len(self.METAINFO['seg_all_class_ids'])
......@@ -128,8 +130,8 @@ class ScanNetDataset(Det3DDataset):
info (dict): Raw info dict.
Returns:
dict: Data information that will be passed to the data
preprocessing transforms. It includes the following keys:
dict: Has `ann_info` in training stage. And
all paths have been converted to absolute paths.
"""
info['axis_align_matrix'] = self._get_axis_align_matrix(info)
info['pts_instance_mask_path'] = osp.join(
......@@ -146,13 +148,13 @@ class ScanNetDataset(Det3DDataset):
return info
def parse_ann_info(self, info: dict) -> dict:
"""Process the `instances` in data info to `ann_info`
"""Process the `instances` in data info to `ann_info`.
Args:
info (dict): Info dict.
Returns:
dict: Processed `ann_info`
dict: Processed `ann_info`.
"""
ann_info = super().parse_ann_info(info)
# empty gt
......@@ -181,24 +183,25 @@ class ScanNetSegDataset(Seg3DDataset):
for data downloading.
Args:
data_root (str): Path of dataset root.
ann_file (str): Path of annotation file.
pipeline (list[dict], optional): Pipeline used for data processing.
Defaults to None.
classes (tuple[str], optional): Classes used in the dataset.
Defaults to None.
palette (list[list[int]], optional): The palette of segmentation map.
Defaults to None.
modality (dict, optional): Modality to specify the sensor data used
as input. Defaults to None.
test_mode (bool, optional): Whether the dataset is in test mode.
Defaults to False.
data_root (str, optional): Path of dataset root. Defaults to None.
ann_file (str): Path of annotation file. Defaults to ''.
pipeline (list[dict]): Pipeline used for data processing.
Defaults to [].
metainfo (dict, optional): Meta information for dataset, such as class
information. Defaults to None.
data_prefix (dict): Prefix for training data. Defaults to
dict(pts='velodyne', img='', instance_mask='', semantic_mask='').
modality (dict): Modality to specify the sensor data used as input.
Defaults to dict(use_lidar=True, use_camera=False).
ignore_index (int, optional): The label index to be ignored, e.g.
unannotated points. If None is given, set to len(self.CLASSES).
unannotated points. If None is given, set to len(self.CLASSES) to
be consistent with PointSegClassMapping function in pipeline.
Defaults to None.
scene_idxs (np.ndarray | str, optional): Precomputed index to load
data. For scenes with many points, we may sample it several times.
Defaults to None.
test_mode (bool): Whether the dataset is in test mode.
Defaults to False.
"""
METAINFO = {
'CLASSES':
......@@ -242,9 +245,9 @@ class ScanNetSegDataset(Seg3DDataset):
pts='points', img='', instance_mask='', semantic_mask=''),
pipeline: List[Union[dict, Callable]] = [],
modality: dict = dict(use_lidar=True, use_camera=False),
ignore_index=None,
scene_idxs=None,
test_mode=False,
ignore_index: Optional[int] = None,
scene_idxs: Optional[Union[np.ndarray, str]] = None,
test_mode: bool = False,
**kwargs) -> None:
super().__init__(
data_root=data_root,
......@@ -315,10 +318,10 @@ class ScanNetInstanceSegDataset(Seg3DDataset):
pts='points', img='', instance_mask='', semantic_mask=''),
pipeline: List[Union[dict, Callable]] = [],
modality: dict = dict(use_lidar=True, use_camera=False),
test_mode=False,
ignore_index=None,
scene_idxs=None,
file_client_args=dict(backend='disk'),
test_mode: bool = False,
ignore_index: Optional[int] = None,
scene_idxs: Optional[Union[np.ndarray, str]] = None,
file_client_args: dict = dict(backend='disk'),
**kwargs) -> None:
super().__init__(
data_root=data_root,
......
......@@ -16,24 +16,20 @@ class Seg3DDataset(BaseDataset):
This is the base dataset of ScanNet, S3DIS and SemanticKITTI dataset.
Args:
data_root (str): Path of dataset root.
ann_file (str): Path of annotation file.
pipeline (list[dict], optional): Pipeline used for data processing.
Defaults to None.
data_root (str, optional): Path of dataset root. Defaults to None.
ann_file (str): Path of annotation file. Defaults to ''.
metainfo (dict, optional): Meta information for dataset, such as class
information. Defaults to None.
data_prefix (dict, optional): Prefix for training data. Defaults to
data_prefix (dict): Prefix for training data. Defaults to
dict(pts='velodyne', img='', instance_mask='', semantic_mask='').
pipeline (list[dict], optional): Pipeline used for data processing.
Defaults to None.
modality (dict, optional): Modality to specify the sensor data used
as input, it usually has following keys.
pipeline (list[dict]): Pipeline used for data processing.
Defaults to [].
modality (dict): Modality to specify the sensor data used
as input, it usually has following keys:
- use_camera: bool
- use_lidar: bool
Defaults to `dict(use_lidar=True, use_camera=False)`
test_mode (bool, optional): Whether the dataset is in test mode.
Defaults to False.
Defaults to dict(use_lidar=True, use_camera=False).
ignore_index (int, optional): The label index to be ignored, e.g.
unannotated points. If None is given, set to len(self.CLASSES) to
be consistent with PointSegClassMapping function in pipeline.
......@@ -41,11 +37,13 @@ class Seg3DDataset(BaseDataset):
scene_idxs (np.ndarray | str, optional): Precomputed index to load
data. For scenes with many points, we may sample it several times.
Defaults to None.
load_eval_anns (bool): Whether to load annotations
in test_mode, the annotation will be save in
`eval_ann_infos`, which can be use in Evaluator.
test_mode (bool): Whether the dataset is in test mode.
Defaults to False.
load_eval_anns (bool): Whether to load annotations in test_mode,
the annotation will be saved in `eval_ann_infos`, which can be used
in Evaluator. Defaults to True.
file_client_args (dict): Configuration of file client.
Defaults to `dict(backend='disk')`.
Defaults to dict(backend='disk').
"""
METAINFO = {
'CLASSES': None, # names of all classes data used for the task
......@@ -66,7 +64,7 @@ class Seg3DDataset(BaseDataset):
pipeline: List[Union[dict, Callable]] = [],
modality: dict = dict(use_lidar=True, use_camera=False),
ignore_index: Optional[int] = None,
scene_idxs: Optional[str] = None,
scene_idxs: Optional[Union[str, np.ndarray]] = None,
test_mode: bool = False,
load_eval_anns: bool = True,
file_client_args: dict = dict(backend='disk'),
......@@ -121,6 +119,7 @@ class Seg3DDataset(BaseDataset):
self.metainfo['seg_label_mapping'] = self.seg_label_mapping
self.scene_idxs = self.get_scene_idxs(scene_idxs)
self.data_list = [self.data_list[i] for i in self.scene_idxs]
# set group flag for the sampler
if not self.test_mode:
......@@ -141,10 +140,9 @@ class Seg3DDataset(BaseDataset):
new_classes (list, tuple, optional): The new classes name from
metainfo. Default to None.
Returns:
tuple: The mapping from old classes in cls.METAINFO to
new classes in metainfo
new classes in metainfo
"""
old_classes = self.METAINFO.get('CLASSES', None)
if (new_classes is not None and old_classes is not None
......
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Callable, List, Optional, Union
import numpy as np
from mmdet3d.registry import DATASETS
from .seg3d_dataset import Seg3DDataset
......@@ -14,26 +16,28 @@ class SemanticKITTIDataset(Seg3DDataset):
for data downloading
Args:
data_root (str): Path of dataset root.
ann_file (str): Path of annotation file.
pipeline (list[dict], optional): Pipeline used for data processing.
data_root (str, optional): Path of dataset root. Defaults to None.
ann_file (str): Path of annotation file. Defaults to ''.
metainfo (dict, optional): Meta information for dataset, such as class
information. Defaults to None.
data_prefix (dict): Prefix for training data. Defaults to
dict(pts='points', img='', instance_mask='', semantic_mask='').
pipeline (list[dict]): Pipeline used for data processing.
Defaults to [].
modality (dict): Modality to specify the sensor data used as input,
it usually has following keys:
- use_camera: bool
- use_lidar: bool
Defaults to dict(use_lidar=True, use_camera=False).
ignore_index (int, optional): The label index to be ignored, e.g.
unannotated points. If None is given, set to len(self.CLASSES) to
be consistent with PointSegClassMapping function in pipeline.
Defaults to None.
classes (tuple[str], optional): Classes used in the dataset.
scene_idxs (np.ndarray | str, optional): Precomputed index to load
data. For scenes with many points, we may sample it several times.
Defaults to None.
modality (dict, optional): Modality to specify the sensor data used
as input. Defaults to None.
box_type_3d (str, optional): NO 3D box for this dataset.
You can choose any type
Based on the `box_type_3d`, the dataset will encapsulate the box
to its original format then converted them to `box_type_3d`.
Defaults to 'LiDAR' in this dataset. Available options includes
- 'LiDAR': Box in LiDAR coordinates.
- 'Depth': Box in depth coordinates, usually for indoor dataset.
- 'Camera': Box in camera coordinates.
filter_empty_gt (bool, optional): Whether to filter empty GT.
Defaults to True.
test_mode (bool, optional): Whether the dataset is in test mode.
test_mode (bool): Whether the dataset is in test mode.
Defaults to False.
"""
METAINFO = {
......@@ -55,9 +59,9 @@ class SemanticKITTIDataset(Seg3DDataset):
pts='points', img='', instance_mask='', semantic_mask=''),
pipeline: List[Union[dict, Callable]] = [],
modality: dict = dict(use_lidar=True, use_camera=False),
ignore_index=None,
scene_idxs=None,
test_mode=False,
ignore_index: Optional[int] = None,
scene_idxs: Optional[Union[str, np.ndarray]] = None,
test_mode: bool = False,
**kwargs) -> None:
super().__init__(
......
......@@ -24,13 +24,13 @@ class SUNRGBDDataset(Det3DDataset):
ann_file (str): Path of annotation file.
metainfo (dict, optional): Meta information for dataset, such as class
information. Defaults to None.
data_prefix (dict, optiona;): Prefix for data. Defaults to
data_prefix (dict): Prefix for data. Defaults to
dict(pts='points',img='sunrgbd_trainval').
pipeline (list[dict], optional): Pipeline used for data processing.
Defaults to None.
modality (dict, optional): Modality to specify the sensor data used
as input. Defaults to dict(use_camera=True, use_lidar=True).
default_cam_key (str, optional): The default camera name adopted.
pipeline (list[dict]): Pipeline used for data processing.
Defaults to [].
modality (dict): Modality to specify the sensor data used as input.
Defaults to dict(use_camera=True, use_lidar=True).
default_cam_key (str): The default camera name adopted.
Defaults to 'CAM0'.
box_type_3d (str): Type of 3D box of this dataset.
Based on the `box_type_3d`, the dataset will encapsulate the box
......@@ -40,9 +40,9 @@ class SUNRGBDDataset(Det3DDataset):
- 'LiDAR': Box in LiDAR coordinates.
- 'Depth': Box in depth coordinates, usually for indoor dataset.
- 'Camera': Box in camera coordinates.
filter_empty_gt (bool, optional): Whether to filter empty GT.
filter_empty_gt (bool): Whether to filter empty GT.
Defaults to True.
test_mode (bool, optional): Whether the dataset is in test mode.
test_mode (bool): Whether the dataset is in test mode.
Defaults to False.
"""
METAINFO = {
......@@ -58,11 +58,11 @@ class SUNRGBDDataset(Det3DDataset):
pts='points', img='sunrgbd_trainval/image'),
pipeline: List[Union[dict, Callable]] = [],
default_cam_key: str = 'CAM0',
modality=dict(use_camera=True, use_lidar=True),
modality: dict = dict(use_camera=True, use_lidar=True),
box_type_3d: str = 'Depth',
filter_empty_gt: bool = True,
test_mode: bool = False,
**kwargs):
**kwargs) -> None:
super().__init__(
data_root=data_root,
ann_file=ann_file,
......@@ -121,7 +121,7 @@ class SUNRGBDDataset(Det3DDataset):
return info
def parse_ann_info(self, info: dict) -> dict:
"""Process the `instances` in data info to `ann_info`
"""Process the `instances` in data info to `ann_info`.
Args:
info (dict): Info dict.
......
......@@ -18,9 +18,8 @@ class BatchSampler:
sample_list (list[dict]): List of samples.
name (str, optional): The category of samples. Defaults to None.
epoch (int, optional): Sampling epoch. Defaults to None.
shuffle (bool, optional): Whether to shuffle indices.
Defaults to False.
drop_reminder (bool, optional): Drop reminder. Defaults to False.
shuffle (bool): Whether to shuffle indices. Defaults to False.
drop_reminder (bool): Drop reminder. Defaults to False.
"""
def __init__(self,
......@@ -90,12 +89,11 @@ class DataBaseSampler(object):
prepare (dict): Name of preparation functions and the input value.
sample_groups (dict): Sampled classes and numbers.
classes (list[str], optional): List of classes. Defaults to None.
points_loader(dict, optional): Config of points loader. Defaults to
points_loader (dict): Config of points loader. Defaults to
dict(type='LoadPointsFromFile', load_dim=4, use_dim=[0, 1, 2, 3]).
file_client_args (dict, optional): Config dict of file clients,
refer to
https://github.com/open-mmlab/mmengine/blob/main/mmengine/fileio/file_client.py
for more details. Defaults to dict(backend='disk').
file_client_args (dict): Arguments to instantiate a FileClient.
See :class:`mmengine.fileio.FileClient` for details.
Defaults to dict(backend='disk').
"""
def __init__(
......@@ -219,9 +217,9 @@ class DataBaseSampler(object):
dict: Dict of sampled 'pseudo ground truths'.
- gt_labels_3d (np.ndarray): ground truths labels
of sampled objects.
of sampled objects.
- gt_bboxes_3d (:obj:`BaseInstance3DBoxes`):
sampled ground truth 3D bounding boxes
sampled ground truth 3D bounding boxes
- points (np.ndarray): sampled points
- group_ids (np.ndarray): ids of sampled ground truths
"""
......
......@@ -102,7 +102,7 @@ class Pack3DDetInputs(BaseTransform):
- points
- img
- 'data_samples' (obj:`Det3DDataSample`): The annotation info of
- 'data_samples' (:obj:`Det3DDataSample`): The annotation info of
the sample.
"""
# augtest
......
......@@ -20,19 +20,17 @@ class LoadMultiViewImageFromFiles(BaseTransform):
Expects results['img_filename'] to be a list of filenames.
Args:
to_float32 (bool, optional): Whether to convert the img to float32.
to_float32 (bool): Whether to convert the img to float32.
Defaults to False.
color_type (str, optional): Color type of the file.
Defaults to 'unchanged'.
file_client_args (dict): Config dict of file clients,
refer to
https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py
for more details. Defaults to dict(backend='disk').
num_views (int): num of view in a frame. Default to 5.
num_ref_frames (int): num of frame in loading. Default to -1.
test_mode (bool): Whether is test mode in loading. Default to False.
set_default_scale (bool): Whether to set default scale. Default to
True.
color_type (str): Color type of the file. Defaults to 'unchanged'.
file_client_args (dict): Arguments to instantiate a FileClient.
See :class:`mmengine.fileio.FileClient` for details.
Defaults to dict(backend='disk').
num_views (int): Number of view in a frame. Defaults to 5.
num_ref_frames (int): Number of frame in loading. Defaults to -1.
test_mode (bool): Whether is test mode in loading. Defaults to False.
set_default_scale (bool): Whether to set default scale.
Defaults to True.
"""
def __init__(self,
......@@ -63,7 +61,7 @@ class LoadMultiViewImageFromFiles(BaseTransform):
Returns:
dict: The result dict containing the multi-view image data.
Added keys and values are described below.
Added keys and values are described below.
- filename (str): Multi-view image filenames.
- img (np.ndarray): Multi-view image arrays.
......@@ -210,7 +208,7 @@ class LoadMultiViewImageFromFiles(BaseTransform):
results['num_ref_frames'] = self.num_ref_frames
return results
def __repr__(self):
def __repr__(self) -> str:
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f'(to_float32={self.to_float32}, '
......@@ -276,22 +274,17 @@ class LoadPointsFromMultiSweeps(BaseTransform):
This is usually used for nuScenes dataset to utilize previous sweeps.
Args:
sweeps_num (int, optional): Number of sweeps. Defaults to 10.
load_dim (int, optional): Dimension number of the loaded points.
Defaults to 5.
use_dim (list[int], optional): Which dimension to use.
Defaults to [0, 1, 2, 4].
file_client_args (dict, optional): Config dict of file clients,
refer to
https://github.com/open-mmlab/mmengine/blob/main/mmengine/fileio/file_client.py
for more details. Defaults to dict(backend='disk').
pad_empty_sweeps (bool, optional): Whether to repeat keyframe when
sweeps_num (int): Number of sweeps. Defaults to 10.
load_dim (int): Dimension number of the loaded points. Defaults to 5.
use_dim (list[int]): Which dimension to use. Defaults to [0, 1, 2, 4].
file_client_args (dict): Arguments to instantiate a FileClient.
See :class:`mmengine.fileio.FileClient` for details.
Defaults to dict(backend='disk').
pad_empty_sweeps (bool): Whether to repeat keyframe when
sweeps is empty. Defaults to False.
remove_close (bool, optional): Whether to remove close points.
Defaults to False.
test_mode (bool, optional): If `test_mode=True`, it will not
randomly sample sweeps but select the nearest N frames.
Defaults to False.
remove_close (bool): Whether to remove close points. Defaults to False.
test_mode (bool): If `test_mode=True`, it will not randomly sample
sweeps but select the nearest N frames. Defaults to False.
"""
def __init__(self,
......@@ -336,11 +329,11 @@ class LoadPointsFromMultiSweeps(BaseTransform):
def _remove_close(self,
points: Union[np.ndarray, BasePoints],
radius: float = 1.0) -> Union[np.ndarray, BasePoints]:
"""Removes point too close within a certain radius from origin.
"""Remove point too close within a certain radius from origin.
Args:
points (np.ndarray | :obj:`BasePoints`): Sweep points.
radius (float, optional): Radius below which points are removed.
radius (float): Radius below which points are removed.
Defaults to 1.0.
Returns:
......@@ -366,10 +359,10 @@ class LoadPointsFromMultiSweeps(BaseTransform):
Returns:
dict: The result dict containing the multi-sweep points data.
Updated key and value are described below.
Updated key and value are described below.
- points (np.ndarray | :obj:`BasePoints`): Multi-sweep point
cloud arrays.
cloud arrays.
"""
points = results['points']
points.tensor[:, 4] = 0
......@@ -414,7 +407,7 @@ class LoadPointsFromMultiSweeps(BaseTransform):
results['points'] = points
return results
def __repr__(self):
def __repr__(self) -> str:
"""str: Return a string that describes the module."""
return f'{self.__class__.__name__}(sweeps_num={self.sweeps_num})'
......@@ -444,7 +437,7 @@ class PointSegClassMapping(BaseTransform):
Returns:
dict: The result dict containing the mapped category ids.
Updated key and value are described below.
Updated key and value are described below.
- pts_semantic_mask (np.ndarray): Mapped semantic masks.
"""
......@@ -465,7 +458,7 @@ class PointSegClassMapping(BaseTransform):
return results
def __repr__(self):
def __repr__(self) -> str:
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
return repr_str
......@@ -490,7 +483,7 @@ class NormalizePointsColor(BaseTransform):
Returns:
dict: The result dict containing the normalized points.
Updated key and value are described below.
Updated key and value are described below.
- points (:obj:`BasePoints`): Points after color normalization.
"""
......@@ -505,7 +498,7 @@ class NormalizePointsColor(BaseTransform):
input_dict['points'] = points
return input_dict
def __repr__(self):
def __repr__(self) -> str:
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f'(color_mean={self.color_mean})'
......@@ -533,19 +526,15 @@ class LoadPointsFromFile(BaseTransform):
- 'LIDAR': Points in LiDAR coordinates.
- 'DEPTH': Points in depth coordinates, usually for indoor dataset.
- 'CAMERA': Points in camera coordinates.
load_dim (int, optional): The dimension of the loaded points.
Defaults to 6.
use_dim (list[int] | int, optional): Which dimensions of the points
to use. Defaults to [0, 1, 2]. For KITTI dataset, set use_dim=4
load_dim (int): The dimension of the loaded points. Defaults to 6.
use_dim (list[int] | int): Which dimensions of the points to use.
Defaults to [0, 1, 2]. For KITTI dataset, set use_dim=4
or use_dim=[0, 1, 2, 3] to use the intensity dimension.
shift_height (bool, optional): Whether to use shifted height.
Defaults to False.
use_color (bool, optional): Whether to use color features.
Defaults to False.
file_client_args (dict, optional): Config dict of file clients,
refer to
https://github.com/open-mmlab/mmengine/blob/main/mmengine/fileio/file_client.py
for more details. Defaults to dict(backend='disk').
shift_height (bool): Whether to use shifted height. Defaults to False.
use_color (bool): Whether to use color features. Defaults to False.
file_client_args (dict): Arguments to instantiate a FileClient.
See :class:`mmengine.fileio.FileClient` for details.
Defaults to dict(backend='disk').
"""
def __init__(
......@@ -602,7 +591,7 @@ class LoadPointsFromFile(BaseTransform):
Returns:
dict: The result dict containing the point clouds data.
Added key and value are described below.
Added key and value are described below.
- points (:obj:`BasePoints`): Point clouds data.
"""
......@@ -638,7 +627,7 @@ class LoadPointsFromFile(BaseTransform):
return results
def __repr__(self):
def __repr__(self) -> str:
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__ + '('
repr_str += f'shift_height={self.shift_height}, '
......@@ -688,7 +677,7 @@ class LoadAnnotations3D(LoadAnnotations):
- pts_instance_mask_path (str): Path of instance mask file.
Only when `with_mask_3d` is True.
- pts_semantic_mask_path (str): Path of semantic mask file.
Only when
Only when `with_seg_3d` is True.
Added Keys:
......@@ -713,33 +702,25 @@ class LoadAnnotations3D(LoadAnnotations):
Only when `with_seg_3d` is True.
Args:
with_bbox_3d (bool, optional): Whether to load 3D boxes.
Defaults to True.
with_label_3d (bool, optional): Whether to load 3D labels.
Defaults to True.
with_attr_label (bool, optional): Whether to load attribute label.
Defaults to False.
with_mask_3d (bool, optional): Whether to load 3D instance masks.
for points. Defaults to False.
with_seg_3d (bool, optional): Whether to load 3D semantic masks.
for points. Defaults to False.
with_bbox (bool, optional): Whether to load 2D boxes.
Defaults to False.
with_label (bool, optional): Whether to load 2D labels.
with_bbox_3d (bool): Whether to load 3D boxes. Defaults to True.
with_label_3d (bool): Whether to load 3D labels. Defaults to True.
with_attr_label (bool): Whether to load attribute label.
Defaults to False.
with_mask (bool, optional): Whether to load 2D instance masks.
with_mask_3d (bool): Whether to load 3D instance masks for points.
Defaults to False.
with_seg (bool, optional): Whether to load 2D semantic masks.
with_seg_3d (bool): Whether to load 3D semantic masks for points.
Defaults to False.
with_bbox_depth (bool, optional): Whether to load 2.5D boxes.
Defaults to False.
poly2mask (bool, optional): Whether to convert polygon annotations
to bitmasks. Defaults to True.
seg_3d_dtype (dtype, optional): Dtype of 3D semantic masks.
Defaults to int64.
file_client_args (dict): Config dict of file clients, refer to
https://github.com/open-mmlab/mmengine/blob/main/mmengine/fileio/file_client.py
for more details.
with_bbox (bool): Whether to load 2D boxes. Defaults to False.
with_label (bool): Whether to load 2D labels. Defaults to False.
with_mask (bool): Whether to load 2D instance masks. Defaults to False.
with_seg (bool): Whether to load 2D semantic masks. Defaults to False.
with_bbox_depth (bool): Whether to load 2.5D boxes. Defaults to False.
poly2mask (bool): Whether to convert polygon annotations to bitmasks.
Defaults to True.
seg_3d_dtype (dtype): Dtype of 3D semantic masks. Defaults to int64.
file_client_args (dict): Arguments to instantiate a FileClient.
See :class:`mmengine.fileio.FileClient` for details.
Defaults to dict(backend='disk').
"""
def __init__(
......@@ -889,7 +870,8 @@ class LoadAnnotations3D(LoadAnnotations):
`ignore_flag`
Args:
results (dict): Result dict from :obj:`mmcv.BaseDataset`.
Returns:
dict: The dict contains loaded bounding box annotations.
"""
......@@ -900,7 +882,7 @@ class LoadAnnotations3D(LoadAnnotations):
"""Private function to load label annotations.
Args:
results (dict): Result dict from :obj:`mmcv.BaseDataset`.
Returns:
dict: The dict contains loaded label annotations.
......@@ -933,7 +915,7 @@ class LoadAnnotations3D(LoadAnnotations):
return results
def __repr__(self):
def __repr__(self) -> str:
"""str: Return a string that describes the module."""
indent_str = ' '
repr_str = self.__class__.__name__ + '(\n'
......
......@@ -19,18 +19,17 @@ class MultiScaleFlipAug3D(BaseTransform):
img_scale (tuple | list[tuple]): Images scales for resizing.
pts_scale_ratio (float | list[float]): Points scale ratios for
resizing.
flip (bool): Whether apply flip augmentation. Defaults to False.
flip_direction (str | list[str]): Flip augmentation directions
    for images, options are "horizontal" and "vertical".
    If flip_direction is list, multiple flip augmentations will
    be applied. It has no effect when ``flip == False``.
    Defaults to 'horizontal'.
pcd_horizontal_flip (bool): Whether to apply horizontal flip
    augmentation to point cloud. Defaults to False.
    Note that it works only when 'flip' is turned on.
pcd_vertical_flip (bool): Whether to apply vertical flip
    augmentation to point cloud. Defaults to False.
    Note that it works only when 'flip' is turned on.
"""
......@@ -75,7 +74,7 @@ class MultiScaleFlipAug3D(BaseTransform):
Returns:
List[dict]: The list contains the data that is augmented with
different scales and flips.
"""
aug_data_list = []
......@@ -112,7 +111,7 @@ class MultiScaleFlipAug3D(BaseTransform):
return aug_data_list
def __repr__(self):
def __repr__(self) -> str:
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f'(transforms={self.transforms}, '
......
......@@ -30,7 +30,7 @@ class RandomDropPointsColor(BaseTransform):
util/transform.py#L223>`_ for more details.
Args:
drop_ratio (float, optional): The probability of dropping point colors.
drop_ratio (float): The probability of dropping point colors.
Defaults to 0.2.
"""
......@@ -46,8 +46,8 @@ class RandomDropPointsColor(BaseTransform):
input_dict (dict): Result dict from loading pipeline.
Returns:
dict: Results after color dropping, 'points' key is updated
    in the result dict.
"""
points = input_dict['points']
assert points.attribute_dims is not None and \
......@@ -64,7 +64,7 @@ class RandomDropPointsColor(BaseTransform):
points.color = points.color * 0.0
return input_dict
def __repr__(self):
def __repr__(self) -> str:
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f'(drop_ratio={self.drop_ratio})'
......@@ -108,8 +108,8 @@ class RandomFlip3D(RandomFlip):
in vertical direction. Defaults to 0.0.
flip_box3d (bool): Whether to flip bounding box. In most of the case,
the box should be flipped. In cam-based bev detection, this is set
to False, since the flip of 2D images does not influence the 3D
box. Defaults to True.
"""
def __init__(self,
......@@ -150,12 +150,11 @@ class RandomFlip3D(RandomFlip):
Args:
input_dict (dict): Result dict from loading pipeline.
direction (str): Flip direction. Defaults to 'horizontal'.
Returns:
dict: Flipped results, 'points', 'bbox3d_fields' keys are
updated in the result dict.
"""
assert direction in ['horizontal', 'vertical']
if self.flip_box3d:
......@@ -210,8 +209,8 @@ class RandomFlip3D(RandomFlip):
Returns:
dict: Flipped results, 'flip', 'flip_direction',
'pcd_horizontal_flip' and 'pcd_vertical_flip' keys are added
into result dict.
"""
# flip 2D image and its annotations
if 'img' in input_dict:
......@@ -241,7 +240,7 @@ class RandomFlip3D(RandomFlip):
input_dict['transformation_3d_flow'].extend(['VF'])
return input_dict
def __repr__(self):
def __repr__(self) -> str:
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f'(sync_2d={self.sync_2d},'
......@@ -254,7 +253,7 @@ class RandomJitterPoints(BaseTransform):
"""Randomly jitter point coordinates.
Different from the global translation in ``GlobalRotScaleTrans``, here we
apply different noises to each point in a scene.
apply different noises to each point in a scene.
Args:
jitter_std (list[float]): The standard deviation of jittering noise.
......@@ -267,7 +266,7 @@ class RandomJitterPoints(BaseTransform):
Note:
This transform should only be used in point cloud segmentation tasks
because we don't transform ground-truth bboxes accordingly.
because we don't transform ground-truth bboxes accordingly.
For similar transform in detection task, please refer to `ObjectNoise`.
"""
......@@ -296,7 +295,7 @@ class RandomJitterPoints(BaseTransform):
Returns:
dict: Results after adding noise to each point,
'points' key is updated in the result dict.
'points' key is updated in the result dict.
"""
points = input_dict['points']
jitter_std = np.array(self.jitter_std, dtype=np.float32)
......@@ -309,7 +308,7 @@ class RandomJitterPoints(BaseTransform):
points.translate(jitter_noise)
return input_dict
def __repr__(self):
def __repr__(self) -> str:
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f'(jitter_std={self.jitter_std},'
......@@ -344,11 +343,11 @@ class ObjectSample(BaseTransform):
Args:
db_sampler (dict): Config dict of the database sampler.
sample_2d (bool): Whether to also paste 2D image patch to the images.
    This should be true when applying multi-modality cut-and-paste.
    Defaults to False.
use_ground_plane (bool): Whether to use ground plane to adjust the
    3D labels. Defaults to False.
"""
def __init__(self,
......@@ -386,8 +385,8 @@ class ObjectSample(BaseTransform):
Returns:
dict: Results after object sampling augmentation,
'points', 'gt_bboxes_3d', 'gt_labels_3d' keys are updated
in the result dict.
'points', 'gt_bboxes_3d', 'gt_labels_3d' keys are updated
in the result dict.
"""
gt_bboxes_3d = input_dict['gt_bboxes_3d']
gt_labels_3d = input_dict['gt_labels_3d']
......@@ -445,12 +444,12 @@ class ObjectSample(BaseTransform):
return input_dict
def __repr__(self) -> str:
    """str: Return a string that describes the module."""
    # Defect fixed: the diff residue left both old and new variants of the
    # signature and of two f-string lines (old ones missing the opening
    # '(' and closing ')'); keep the corrected, balanced version only.
    repr_str = self.__class__.__name__
    repr_str += f'(db_sampler={self.db_sampler},'
    repr_str += f' sample_2d={self.sample_2d},'
    repr_str += f' use_ground_plane={self.use_ground_plane})'
    return repr_str
......@@ -469,15 +468,15 @@ class ObjectNoise(BaseTransform):
- gt_bboxes_3d
Args:
translation_std (list[float]): Standard deviation of the
    distribution where translation noise is sampled from.
    Defaults to [0.25, 0.25, 0.25].
global_rot_range (list[float]): Global rotation to the scene.
    Defaults to [0.0, 0.0].
rot_range (list[float]): Object rotation range.
    Defaults to [-0.15707963267, 0.15707963267].
num_try (int): Number of times to try if the noise applied is invalid.
    Defaults to 100.
"""
def __init__(self,
......@@ -498,7 +497,7 @@ class ObjectNoise(BaseTransform):
Returns:
dict: Results after adding noise to each object,
'points', 'gt_bboxes_3d' keys are updated in the result dict.
'points', 'gt_bboxes_3d' keys are updated in the result dict.
"""
gt_bboxes_3d = input_dict['gt_bboxes_3d']
points = input_dict['points']
......@@ -519,7 +518,7 @@ class ObjectNoise(BaseTransform):
input_dict['points'] = points.new_point(numpy_points)
return input_dict
def __repr__(self):
def __repr__(self) -> str:
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f'(num_try={self.num_try},'
......@@ -538,10 +537,10 @@ class GlobalAlignment(BaseTransform):
Note:
We do not record the applied rotation and translation as in
GlobalRotScaleTrans. Because usually, we do not need to reverse
the alignment step.
For example, ScanNet 3D detection task uses aligned ground-truth
bounding boxes for evaluation.
"""
def __init__(self, rotation_axis: int) -> None:
......@@ -593,7 +592,7 @@ class GlobalAlignment(BaseTransform):
Returns:
dict: Results after global alignment, 'points' and keys in
input_dict['bbox3d_fields'] are updated in the result dict.
input_dict['bbox3d_fields'] are updated in the result dict.
"""
assert 'axis_align_matrix' in results, \
'axis_align_matrix is not provided in GlobalAlignment'
......@@ -610,7 +609,7 @@ class GlobalAlignment(BaseTransform):
return results
def __repr__(self):
def __repr__(self) -> str:
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f'(rotation_axis={self.rotation_axis})'
......@@ -640,15 +639,15 @@ class GlobalRotScaleTrans(BaseTransform):
- pcd_scale_factor (np.float32)
Args:
rot_range (list[float]): Range of rotation angle.
    Defaults to [-0.78539816, 0.78539816] (close to [-pi/4, pi/4]).
scale_ratio_range (list[float]): Range of scale ratio.
    Defaults to [0.95, 1.05].
translation_std (list[float]): The standard deviation of
    translation noise applied to a scene, which
    is sampled from a gaussian distribution whose standard deviation
    is set by ``translation_std``. Defaults to [0, 0, 0].
shift_height (bool): Whether to shift height.
(the fourth dimension of indoor points) when scaling.
Defaults to False.
"""
......@@ -689,8 +688,7 @@ class GlobalRotScaleTrans(BaseTransform):
Returns:
dict: Results after translation, 'points', 'pcd_trans'
and `gt_bboxes_3d` is updated
in the result dict.
and `gt_bboxes_3d` is updated in the result dict.
"""
translation_std = np.array(self.translation_std, dtype=np.float32)
trans_factor = np.random.normal(scale=translation_std, size=3).T
......@@ -708,8 +706,7 @@ class GlobalRotScaleTrans(BaseTransform):
Returns:
dict: Results after rotation, 'points', 'pcd_rotation'
and `gt_bboxes_3d` is updated
in the result dict.
and `gt_bboxes_3d` is updated in the result dict.
"""
rotation = self.rot_range
noise_rotation = np.random.uniform(rotation[0], rotation[1])
......@@ -735,8 +732,7 @@ class GlobalRotScaleTrans(BaseTransform):
Returns:
dict: Results after scaling, 'points' and
`gt_bboxes_3d` is updated
in the result dict.
`gt_bboxes_3d` is updated in the result dict.
"""
scale = input_dict['pcd_scale_factor']
points = input_dict['points']
......@@ -774,7 +770,7 @@ class GlobalRotScaleTrans(BaseTransform):
Returns:
dict: Results after scaling, 'points', 'pcd_rotation',
'pcd_scale_factor', 'pcd_trans' and `gt_bboxes_3d` are updated
in the result dict.
"""
if 'transformation_3d_flow' not in input_dict:
......@@ -791,7 +787,7 @@ class GlobalRotScaleTrans(BaseTransform):
input_dict['transformation_3d_flow'].extend(['R', 'S', 'T'])
return input_dict
def __repr__(self):
def __repr__(self) -> str:
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f'(rot_range={self.rot_range},'
......@@ -813,7 +809,7 @@ class PointShuffle(BaseTransform):
Returns:
dict: Results after filtering, 'points', 'pts_instance_mask'
and 'pts_semantic_mask' keys are updated in the result dict.
and 'pts_semantic_mask' keys are updated in the result dict.
"""
idx = input_dict['points'].shuffle()
idx = idx.numpy()
......@@ -829,7 +825,7 @@ class PointShuffle(BaseTransform):
return input_dict
def __repr__(self):
def __repr__(self) -> str:
"""str: Return a string that describes the module."""
return self.__class__.__name__
......@@ -850,7 +846,7 @@ class ObjectRangeFilter(BaseTransform):
point_cloud_range (list[float]): Point cloud range.
"""
def __init__(self, point_cloud_range: List[float]) -> None:
    """Initialize the filter with the valid point cloud range.

    Args:
        point_cloud_range (list[float]): Point cloud range,
            [x_min, y_min, z_min, x_max, y_max, z_max].
    """
    # Defect fixed: duplicated `def` line (diff residue) — keep the
    # annotated version; behavior is unchanged.
    self.pcd_range = np.array(point_cloud_range, dtype=np.float32)
def transform(self, input_dict: dict) -> dict:
......@@ -861,7 +857,7 @@ class ObjectRangeFilter(BaseTransform):
Returns:
dict: Results after filtering, 'gt_bboxes_3d', 'gt_labels_3d'
keys are updated in the result dict.
keys are updated in the result dict.
"""
# Check points instance type and initialise bev_range
if isinstance(input_dict['gt_bboxes_3d'],
......@@ -887,7 +883,7 @@ class ObjectRangeFilter(BaseTransform):
return input_dict
def __repr__(self):
def __repr__(self) -> str:
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f'(point_cloud_range={self.pcd_range.tolist()})'
......@@ -923,7 +919,7 @@ class PointsRangeFilter(BaseTransform):
Returns:
dict: Results after filtering, 'points', 'pts_instance_mask'
and 'pts_semantic_mask' keys are updated in the result dict.
and 'pts_semantic_mask' keys are updated in the result dict.
"""
points = input_dict['points']
points_mask = points.in_range_3d(self.pcd_range)
......@@ -942,7 +938,7 @@ class PointsRangeFilter(BaseTransform):
return input_dict
def __repr__(self):
def __repr__(self) -> str:
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f'(point_cloud_range={self.pcd_range.tolist()})'
......@@ -977,7 +973,7 @@ class ObjectNameFilter(BaseTransform):
Returns:
dict: Results after filtering, 'gt_bboxes_3d', 'gt_labels_3d'
keys are updated in the result dict.
keys are updated in the result dict.
"""
gt_labels_3d = input_dict['gt_labels_3d']
gt_bboxes_mask = np.array([n in self.labels for n in gt_labels_3d],
......@@ -987,7 +983,7 @@ class ObjectNameFilter(BaseTransform):
return input_dict
def __repr__(self):
def __repr__(self) -> str:
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f'(classes={self.classes})'
......@@ -1017,8 +1013,8 @@ class PointSample(BaseTransform):
sample_range (float, optional): The range where to sample points.
If not None, the points with depth larger than `sample_range` are
prior to be sampled. Defaults to None.
replace (bool): Whether the sampling is with or without replacement.
    Defaults to False.
"""
def __init__(self,
......@@ -1046,10 +1042,9 @@ class PointSample(BaseTransform):
num_samples (int): Number of samples to be sampled.
sample_range (float, optional): Indicating the range where the
points will be sampled. Defaults to None.
replace (bool): Sampling with or without replacement.
    Defaults to False.
return_choices (bool): Whether return choice. Defaults to False.
Returns:
tuple[:obj:`BasePoints`, np.ndarray] | :obj:`BasePoints`:
......@@ -1089,7 +1084,7 @@ class PointSample(BaseTransform):
Returns:
dict: Results after sampling, 'points', 'pts_instance_mask'
and 'pts_semantic_mask' keys are updated in the result dict.
and 'pts_semantic_mask' keys are updated in the result dict.
"""
points = input_dict['points']
points, choices = self._points_random_sampling(
......@@ -1113,7 +1108,7 @@ class PointSample(BaseTransform):
return input_dict
def __repr__(self):
def __repr__(self) -> str:
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f'(num_points={self.num_points},'
......@@ -1149,7 +1144,7 @@ class IndoorPatchPointSample(BaseTransform):
Args:
num_points (int): Number of points to be sampled.
block_size (float, optional): Size of a block to sample points from.
block_size (float): Size of a block to sample points from.
Defaults to 1.5.
sample_rate (float, optional): Stride used in sliding patch generation.
This parameter is unused in `IndoorPatchPointSample` and thus has
......@@ -1159,24 +1154,24 @@ class IndoorPatchPointSample(BaseTransform):
segmentation task. This is set in PointSegClassMapping as neg_cls.
If not None, will be used as a patch selection criterion.
Defaults to None.
use_normalized_coord (bool): Whether to use normalized xyz as
    additional features. Defaults to False.
num_try (int): Number of times to try if the patch selected is invalid.
    Defaults to 10.
enlarge_size (float): Enlarge the sampled patch to
    [-block_size / 2 - enlarge_size, block_size / 2 + enlarge_size] as
    an augmentation. If None, set it as 0. Defaults to 0.2.
min_unique_num (int, optional): Minimum number of unique points
    the sampled patch should contain. If None, use PointNet++'s method
    to judge uniqueness. Defaults to None.
eps (float): A value added to patch boundary to guarantee
points coverage. Defaults to 1e-2.
Note:
This transform should only be used in the training process of point
cloud segmentation tasks. For the sliding patch generation and
inference process in testing, please refer to the `slide_inference`
function of `EncoderDecoder3D` class.
"""
def __init__(self,
......@@ -1356,7 +1351,7 @@ class IndoorPatchPointSample(BaseTransform):
Returns:
dict: Results after sampling, 'points', 'pts_instance_mask'
and 'pts_semantic_mask' keys are updated in the result dict.
and 'pts_semantic_mask' keys are updated in the result dict.
"""
points = input_dict['points']
......@@ -1386,7 +1381,7 @@ class IndoorPatchPointSample(BaseTransform):
return input_dict
def __repr__(self):
def __repr__(self) -> str:
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f'(num_points={self.num_points},'
......@@ -1405,7 +1400,7 @@ class BackgroundPointsFilter(BaseTransform):
"""Filter background points near the bounding box.
Args:
bbox_enlarge_range (tuple[float], float): Bbox enlarge range.
bbox_enlarge_range (tuple[float] | float): Bbox enlarge range.
"""
def __init__(self, bbox_enlarge_range: Union[Tuple[float], float]) -> None:
......@@ -1427,7 +1422,7 @@ class BackgroundPointsFilter(BaseTransform):
Returns:
dict: Results after filtering, 'points', 'pts_instance_mask'
and 'pts_semantic_mask' keys are updated in the result dict.
and 'pts_semantic_mask' keys are updated in the result dict.
"""
points = input_dict['points']
gt_bboxes_3d = input_dict['gt_bboxes_3d']
......@@ -1458,7 +1453,7 @@ class BackgroundPointsFilter(BaseTransform):
input_dict['pts_semantic_mask'] = pts_semantic_mask[valid_masks]
return input_dict
def __repr__(self):
def __repr__(self) -> str:
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f'(bbox_enlarge_range={self.bbox_enlarge_range.tolist()})'
......@@ -1473,9 +1468,10 @@ class VoxelBasedPointSampler(BaseTransform):
Args:
cur_sweep_cfg (dict): Config for sampling current points.
prev_sweep_cfg (dict): Config for sampling previous points.
prev_sweep_cfg (dict, optional): Config for sampling previous points.
Defaults to None.
time_dim (int): Index that indicate the time dimension
for input points.
for input points. Defaults to 3.
"""
def __init__(self,
......@@ -1502,7 +1498,7 @@ class VoxelBasedPointSampler(BaseTransform):
points (np.ndarray): Points subset to be sampled.
sampler (VoxelGenerator): Voxel based sampler for
each points subset.
point_dim (int): The dimension of each points
point_dim (int): The dimension of each points.
Returns:
np.ndarray: Sampled points.
......@@ -1529,7 +1525,7 @@ class VoxelBasedPointSampler(BaseTransform):
Returns:
dict: Results after sampling, 'points', 'pts_instance_mask'
and 'pts_semantic_mask' keys are updated in the result dict.
and 'pts_semantic_mask' keys are updated in the result dict.
"""
points = results['points']
original_dim = points.shape[1]
......@@ -1589,7 +1585,7 @@ class VoxelBasedPointSampler(BaseTransform):
return results
def __repr__(self):
def __repr__(self) -> str:
"""str: Return a string that describes the module."""
def _auto_indent(repr_str, indent):
......@@ -1625,7 +1621,7 @@ class AffineResize(BaseTransform):
img_scale (tuple): Images scales for resizing.
down_ratio (int): The down ratio of feature map.
Actually the arg should be >= 1.
bbox_clip_border (bool, optional): Whether clip the objects
bbox_clip_border (bool): Whether clip the objects
outside the border of the image. Defaults to True.
"""
......@@ -1646,7 +1642,7 @@ class AffineResize(BaseTransform):
Returns:
dict: Results after affine resize, 'affine_aug', 'trans_mat'
keys are added in the result dict.
keys are added in the result dict.
"""
# The results have gone through RandomShiftScale before AffineResize
if 'center' not in results:
......@@ -1803,7 +1799,7 @@ class AffineResize(BaseTransform):
ref_point3 = ref_point2 + np.array([-d[1], d[0]])
return ref_point3
def __repr__(self):
def __repr__(self) -> str:
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f'(img_scale={self.img_scale}, '
......@@ -1838,7 +1834,7 @@ class RandomShiftScale(BaseTransform):
Returns:
dict: Results after random shift and scale, 'center', 'size'
and 'affine_aug' keys are added in the result dict.
and 'affine_aug' keys are added in the result dict.
"""
img = results['img']
......@@ -1863,7 +1859,7 @@ class RandomShiftScale(BaseTransform):
return results
def __repr__(self):
def __repr__(self) -> str:
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f'(shift_scale={self.shift_scale}, '
......@@ -1874,7 +1870,7 @@ class RandomShiftScale(BaseTransform):
@TRANSFORMS.register_module()
class Resize3D(Resize):
def _resize_3d(self, results):
def _resize_3d(self, results: dict) -> None:
"""Resize centers_2d and modify camera intrinisc with
``results['scale']``."""
if 'centers_2d' in results:
......@@ -1888,6 +1884,7 @@ class Resize3D(Resize):
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Resized results, 'img', 'gt_bboxes', 'gt_seg_map',
'gt_keypoints', 'scale', 'scale_factor', 'img_shape',
......@@ -1909,7 +1906,7 @@ class RandomResize3D(RandomResize):
and cam2img with ``results['scale']``.
"""
def _resize_3d(self, results):
def _resize_3d(self, results: dict) -> None:
"""Resize centers_2d and modify camera intrinisc with
``results['scale']``."""
if 'centers_2d' in results:
......@@ -1917,7 +1914,7 @@ class RandomResize3D(RandomResize):
results['cam2img'][0] *= np.array(results['scale_factor'][0])
results['cam2img'][1] *= np.array(results['scale_factor'][1])
def transform(self, results):
def transform(self, results: dict) -> dict:
"""Transform function to resize images, bounding boxes, masks, semantic
segmentation map. Compared to RandomResize, this function would further
check if scale is already set in results.
......@@ -1926,8 +1923,8 @@ class RandomResize3D(RandomResize):
results (dict): Result dict from loading pipeline.
Returns:
dict: Resized results, 'img_shape', 'pad_shape', 'scale_factor', \
'keep_ratio' keys are added into result dict.
dict: Resized results, 'img_shape', 'pad_shape', 'scale_factor',
'keep_ratio' keys are added into result dict.
"""
if 'scale' not in results:
results['scale'] = self._random_scale()
......@@ -1989,14 +1986,14 @@ class RandomCrop3D(RandomCrop):
on cropped instance masks. Defaults to False.
bbox_clip_border (bool): Whether clip the objects outside
the border of the image. Defaults to True.
rel_offset_h (tuple): The cropping interval of image height. Defaults
    to (0., 1.).
rel_offset_w (tuple): The cropping interval of image width. Defaults
    to (0., 1.).
Note:
- If the image is smaller than the absolute crop size, return the
original image.
original image.
- The keys for bboxes, labels and masks must be aligned. That is,
``gt_bboxes`` corresponds to ``gt_labels`` and ``gt_masks``, and
``gt_bboxes_ignore`` corresponds to ``gt_labels_ignore`` and
......@@ -2005,14 +2002,16 @@ class RandomCrop3D(RandomCrop):
``allow_negative_crop`` is set to False, skip this image.
"""
def __init__(self,
crop_size,
crop_type='absolute',
allow_negative_crop=False,
recompute_bbox=False,
bbox_clip_border=True,
rel_offset_h=(0., 1.),
rel_offset_w=(0., 1.)):
def __init__(
self,
crop_size: tuple,
crop_type: str = 'absolute',
allow_negative_crop: bool = False,
recompute_bbox: bool = False,
bbox_clip_border: bool = True,
rel_offset_h: tuple = (0., 1.),
rel_offset_w: tuple = (0., 1.)
) -> None:
super().__init__(
crop_size=crop_size,
crop_type=crop_type,
......@@ -2024,7 +2023,10 @@ class RandomCrop3D(RandomCrop):
self.rel_offset_h = rel_offset_h
self.rel_offset_w = rel_offset_w
def _crop_data(self, results, crop_size, allow_negative_crop):
def _crop_data(self,
results: dict,
crop_size: tuple,
allow_negative_crop: bool = False) -> dict:
"""Function to randomly crop images, bounding boxes, masks, semantic
segmentation maps.
......@@ -2032,11 +2034,11 @@ class RandomCrop3D(RandomCrop):
results (dict): Result dict from loading pipeline.
crop_size (tuple): Expected absolute size after cropping, (h, w).
allow_negative_crop (bool): Whether to allow a crop that does not
contain any bbox area. Default to False.
contain any bbox area. Defaults to False.
Returns:
dict: Randomly cropped results, 'img_shape' key in result dict is
updated according to crop size.
updated according to crop size.
"""
assert crop_size[0] > 0 and crop_size[1] > 0
for key in results.get('img_fields', ['img']):
......@@ -2119,7 +2121,7 @@ class RandomCrop3D(RandomCrop):
return results
def transform(self, results):
def transform(self, results: dict) -> dict:
"""Transform function to randomly crop images, bounding boxes, masks,
semantic segmentation maps.
......@@ -2128,7 +2130,7 @@ class RandomCrop3D(RandomCrop):
Returns:
dict: Randomly cropped results, 'img_shape' key in result dict is
updated according to crop size.
updated according to crop size.
"""
image_size = results['img'].shape[:2]
if 'crop_size' not in results:
......@@ -2139,7 +2141,8 @@ class RandomCrop3D(RandomCrop):
results = self._crop_data(results, crop_size, self.allow_negative_crop)
return results
def __repr__(self):
def __repr__(self) -> str:
    """str: Return a string that describes the module."""
    repr_str = self.__class__.__name__
    repr_str += f'(crop_size={self.crop_size}, '
    repr_str += f'crop_type={self.crop_type}, '
......@@ -2260,43 +2263,44 @@ class MultiViewWrapper(BaseTransform):
transforms (list[dict]): A list of dict specifying the transformations
for the monocular situation.
override_aug_config (bool): flag of whether to use the same aug config
for multiview image. Defaults to True.
process_fields (list): Desired keys that the transformations should
    be conducted on. Defaults to ['img', 'cam2img', 'lidar2cam'].
collected_keys (list): Collect information in transformation
    like rotate angles, crop roi, and flip state. Defaults to
    ['scale', 'scale_factor', 'crop',
    'crop_offset', 'ori_shape',
    'pad_shape', 'img_shape',
    'pad_fixed_size', 'pad_size_divisor',
    'flip', 'flip_direction', 'rotate'].
randomness_keys (list): The keys related to the randomness
    in transformation. Defaults to
['scale', 'scale_factor', 'crop_size', 'flip',
'flip_direction', 'photometric_param']
"""
def __init__(
    self,
    transforms: dict,
    override_aug_config: bool = True,
    process_fields: list = ['img', 'cam2img', 'lidar2cam'],
    collected_keys: list = [
        'scale', 'scale_factor', 'crop', 'img_crop_offset', 'ori_shape',
        'pad_shape', 'img_shape', 'pad_fixed_size', 'pad_size_divisor',
        'flip', 'flip_direction', 'rotate'
    ],
    randomness_keys: list = [
        'scale', 'scale_factor', 'crop_size', 'img_crop_offset', 'flip',
        'flip_direction', 'photometric_param'
    ]
) -> None:
    """Initialize the multi-view wrapper.

    Defect fixed: diff residue left two complete, conflicting ``__init__``
    signatures back to back; keep the reformatted one only.

    NOTE(review): the list defaults are mutable, but they are only stored
    and read, never mutated, so the shared-default pitfall does not apply.
    """
    self.transforms = Compose(transforms)
    self.override_aug_config = override_aug_config
    self.collected_keys = collected_keys
    self.process_fields = process_fields
    self.randomness_keys = randomness_keys
def transform(self, input_dict):
def transform(self, input_dict: dict) -> dict:
"""Transform function to do the transform for multiview image.
Args:
......
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import Callable, List, Union
import numpy as np
......@@ -24,20 +24,20 @@ class WaymoDataset(KittiDataset):
data_root (str): Path of dataset root.
ann_file (str): Path of annotation file.
data_prefix (dict): data prefix for point cloud and
camera data dict. Default to dict(
camera data dict. Defaults to dict(
pts='velodyne',
CAM_FRONT='image_0',
CAM_FRONT_RIGHT='image_1',
CAM_FRONT_LEFT='image_2',
CAM_SIDE_RIGHT='image_3',
CAM_SIDE_LEFT='image_4')
pipeline (list[dict], optional): Pipeline used for data processing.
Defaults to None.
modality (dict, optional): Modality to specify the sensor data used
pipeline (list[dict]): Pipeline used for data processing.
Defaults to [].
modality (dict): Modality to specify the sensor data used
as input. Defaults to dict(use_lidar=True).
default_cam_key (str, optional): Default camera key for lidar2img
default_cam_key (str): Default camera key for lidar2img
association. Defaults to 'CAM_FRONT'.
box_type_3d (str, optional): Type of 3D box of this dataset.
box_type_3d (str): Type of 3D box of this dataset.
Based on the `box_type_3d`, the dataset will encapsulate the box
to its original format then converted them to `box_type_3d`.
Defaults to 'LiDAR' in this dataset. Available options includes:
......@@ -45,22 +45,23 @@ class WaymoDataset(KittiDataset):
- 'LiDAR': Box in LiDAR coordinates.
- 'Depth': Box in depth coordinates, usually for indoor dataset.
- 'Camera': Box in camera coordinates.
filter_empty_gt (bool, optional): Whether to filter empty GT.
Defaults to True.
test_mode (bool, optional): Whether the dataset is in test mode.
filter_empty_gt (bool): Whether to filter the data with empty GT.
If it's set to be True, the example with empty annotations after
data pipeline will be dropped and a random example will be chosen
in `__getitem__`. Defaults to True.
test_mode (bool): Whether the dataset is in test mode.
Defaults to False.
pcd_limit_range (list[float], optional): The range of point cloud
pcd_limit_range (list[float]): The range of point cloud
used to filter invalid predicted boxes.
Defaults to [-85, -85, -5, 85, 85, 5].
cam_sync_instances (bool, optional): If use the camera sync label
cam_sync_instances (bool): If use the camera sync label
supported from waymo version 1.3.1. Defaults to False.
load_interval (int, optional): load frame interval.
Defaults to 1.
task (str, optional): task for 3D detection (lidar, mono3d).
load_interval (int): load frame interval. Defaults to 1.
task (str): task for 3D detection (lidar, mono3d).
lidar: take all the ground truth in the frame.
mono3d: take the ground truth that can be seen in the cam.
Defaults to 'lidar'.
max_sweeps (int, optional): max sweep for each frame. Defaults to 0.
Defaults to 'lidar_det'.
max_sweeps (int): max sweep for each frame. Defaults to 0.
"""
METAINFO = {'CLASSES': ('Car', 'Pedestrian', 'Cyclist')}
......@@ -75,17 +76,17 @@ class WaymoDataset(KittiDataset):
CAM_SIDE_RIGHT='image_3',
CAM_SIDE_LEFT='image_4'),
pipeline: List[Union[dict, Callable]] = [],
modality: Optional[dict] = dict(use_lidar=True),
modality: dict = dict(use_lidar=True),
default_cam_key: str = 'CAM_FRONT',
box_type_3d: str = 'LiDAR',
filter_empty_gt: bool = True,
test_mode: bool = False,
pcd_limit_range: List[float] = [0, -40, -3, 70.4, 40, 0.0],
cam_sync_instances=False,
load_interval=1,
task='lidar_det',
max_sweeps=0,
**kwargs):
cam_sync_instances: bool = False,
load_interval: int = 1,
task: str = 'lidar_det',
max_sweeps: int = 0,
**kwargs) -> None:
self.load_interval = load_interval
# set loading mode for different task settings
self.cam_sync_instances = cam_sync_instances
......@@ -111,7 +112,7 @@ class WaymoDataset(KittiDataset):
**kwargs)
def parse_ann_info(self, info: dict) -> dict:
"""Get annotation info according to the given index.
"""Process the `instances` in data info to `ann_info`.
Args:
info (dict): Data information of single data sample.
......@@ -120,12 +121,12 @@ class WaymoDataset(KittiDataset):
dict: annotation information consists of the following keys:
- bboxes_3d (:obj:`LiDARInstance3DBoxes`):
3D ground truth bboxes.
3D ground truth bboxes.
- bbox_labels_3d (np.ndarray): Labels of ground truths.
- gt_bboxes (np.ndarray): 2D ground truth bboxes.
- gt_labels (np.ndarray): Labels of ground truths.
- difficulty (int): Difficulty defined by KITTI.
0, 1, 2 represent xxxxx respectively.
0, 1, 2 represent xxxxx respectively.
"""
ann_info = Det3DDataset.parse_ann_info(self, info)
if ann_info is None:
......
......@@ -41,19 +41,20 @@ class Base3DDecodeHead(BaseModule, metaclass=ABCMeta):
Args:
channels (int): Channels after modules, before conv_seg.
num_classes (int): Number of classes.
dropout_ratio (float, optional): Ratio of dropout layer. Default: 0.5.
conv_cfg (dict, optional): Config of conv layers.
Default: dict(type='Conv1d').
norm_cfg (dict, optional): Config of norm layers.
Default: dict(type='BN1d').
act_cfg (dict, optional): Config of activation layers.
Default: dict(type='ReLU').
loss_decode (dict, optional): Config of decode loss.
Default: dict(type='CrossEntropyLoss').
ignore_index (int, optional): The label index to be ignored.
dropout_ratio (float): Ratio of dropout layer. Defaults to 0.5.
conv_cfg (dict): Config of conv layers.
Defaults to dict(type='Conv1d').
norm_cfg (dict): Config of norm layers.
Defaults to dict(type='BN1d').
act_cfg (dict): Config of activation layers.
Defaults to dict(type='ReLU').
loss_decode (dict): Config of decode loss.
Defaults to dict(type='CrossEntropyLoss').
ignore_index (int): The label index to be ignored.
When using masked BCE loss, ignore_index should be set to None.
Default: 255.
Defaults to 255.
init_cfg (dict or list[dict], optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
......@@ -105,8 +106,8 @@ class Base3DDecodeHead(BaseModule, metaclass=ABCMeta):
output = self.conv_seg(feat)
return output
def loss(self, inputs: List[Tensor],
batch_data_samples: SampleList) -> dict:
def loss(self, inputs: List[Tensor], batch_data_samples: SampleList,
train_cfg: ConfigType) -> dict:
"""Forward function for training.
Args:
......
......@@ -140,7 +140,8 @@ class Base3DSegmentor(BaseModel, metaclass=ABCMeta):
def postprocess_result(self, seg_pred_list: List[dict],
batch_img_metas: List[dict]) -> list:
""" Convert results list to `Det3DDataSample`.
"""Convert results list to `Det3DDataSample`.
Args:
seg_logits_list (List[dict]): List of segmentation results,
seg_logits from model of each input point clouds sample.
......@@ -157,7 +158,8 @@ class Base3DSegmentor(BaseModel, metaclass=ABCMeta):
for i in range(len(seg_pred_list)):
img_meta = batch_img_metas[i]
seg_pred = seg_pred_list[i]
prediction = Det3DDataSample(**{'metainfo': img_meta})
prediction = Det3DDataSample(**{'metainfo': img_meta.metainfo})
prediction.set_data({'eval_ann_info': img_meta.eval_ann_info})
prediction.set_data(
{'pred_pts_seg': PointData(**{'pts_semantic_mask': seg_pred})})
predictions.append(prediction)
......
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
from typing import List, Tuple
import numpy as np
import torch
......@@ -40,10 +40,10 @@ class EncoderDecoder3D(Base3DSegmentor):
.. code:: text
predict(): inference() -> postprocess_result()
inference(): whole_inference()/slide_inference()
whole_inference()/slide_inference(): encoder_decoder()
encoder_decoder(): extract_feat() -> decode_head.predict()
predict(): inference() -> postprocess_result()
inference(): whole_inference()/slide_inference()
whole_inference()/slide_inference(): encoder_decoder()
encoder_decoder(): extract_feat() -> decode_head.predict()
4 The ``_forward`` method is used to output the tensor by running the model,
which includes two steps: (1) Extracts features to obtain the feature maps
......@@ -51,7 +51,7 @@ class EncoderDecoder3D(Base3DSegmentor):
.. code:: text
_forward(): extract_feat() -> _decode_head.forward()
_forward(): extract_feat() -> _decode_head.forward()
Args:
......@@ -65,10 +65,10 @@ class EncoderDecoder3D(Base3DSegmentor):
loss. Defaults to None.
train_cfg (OptConfigType): The config for training. Defaults to None.
test_cfg (OptConfigType): The config for testing. Defaults to None.
data_preprocessor (dict, optional): The pre-process config of
:class:`BaseDataPreprocessor`.
init_cfg (dict, optional): The weight initialized config for
:class:`BaseModule`.
data_preprocessor (OptConfigType): The pre-process config of
:class:`BaseDataPreprocessor`. Defaults to None.
init_cfg (OptMultiConfig): The weight initialized config for
:class:`BaseModule`. Defaults to None.
""" # noqa: E501
def __init__(self,
......@@ -80,7 +80,7 @@ class EncoderDecoder3D(Base3DSegmentor):
train_cfg: OptConfigType = None,
test_cfg: OptConfigType = None,
data_preprocessor: OptConfigType = None,
init_cfg: OptMultiConfig = None):
init_cfg: OptMultiConfig = None) -> None:
super(EncoderDecoder3D, self).__init__(
data_preprocessor=data_preprocessor, init_cfg=init_cfg)
self.backbone = MODELS.build(backbone)
......@@ -122,15 +122,15 @@ class EncoderDecoder3D(Base3DSegmentor):
else:
self.loss_regularization = MODELS.build(loss_regularization)
def extract_feat(self, batch_inputs: Tensor) -> Tensor:
    """Extract features from points.

    Runs the backbone on the batched point input and, when a neck is
    configured, refines the backbone output with it.
    """
    features = self.backbone(batch_inputs)
    return self.neck(features) if self.with_neck else features
def encode_decode(self, batch_inputs: torch.Tensor,
batch_input_metas: List[dict]) -> List[Tensor]:
def encode_decode(self, batch_inputs: Tensor,
batch_input_metas: List[dict]) -> Tensor:
"""Encode points with backbone and decode into a semantic segmentation
map of the same size as input.
......@@ -178,7 +178,7 @@ class EncoderDecoder3D(Base3DSegmentor):
return losses
def _loss_regularization_forward_train(self):
def _loss_regularization_forward_train(self) -> dict:
"""Calculate regularization loss for model weight in training."""
losses = dict()
if isinstance(self.loss_regularization, nn.ModuleList):
......@@ -203,7 +203,7 @@ class EncoderDecoder3D(Base3DSegmentor):
- points (list[torch.Tensor]): Point cloud of each sample.
- imgs (torch.Tensor, optional): Image tensor has shape
(B, C, H, W).
(B, C, H, W).
batch_data_samples (list[:obj:`Det3DDataSample`]): The det3d
data samples. It usually includes information such
as `metainfo` and `gt_pts_sem_seg`.
......@@ -213,7 +213,8 @@ class EncoderDecoder3D(Base3DSegmentor):
"""
# extract features using backbone
x = self.extract_feat(batch_inputs_dict)
points = torch.stack(batch_inputs_dict['points'])
x = self.extract_feat(points)
losses = dict()
......@@ -236,7 +237,7 @@ class EncoderDecoder3D(Base3DSegmentor):
patch_center: Tensor,
coord_max: Tensor,
feats: Tensor,
use_normalized_coord: bool = False):
use_normalized_coord: bool = False) -> Tensor:
"""Generating model input.
Generate input by subtracting patch center and adding additional
......@@ -273,7 +274,7 @@ class EncoderDecoder3D(Base3DSegmentor):
block_size: float,
sample_rate: float = 0.5,
use_normalized_coord: bool = False,
eps: float = 1e-3):
eps: float = 1e-3) -> Tuple[Tensor, Tensor]:
"""Sampling points in a sliding window fashion.
First sample patches to cover all the input points.
......@@ -291,12 +292,12 @@ class EncoderDecoder3D(Base3DSegmentor):
points coverage. Defaults to 1e-3.
Returns:
np.ndarray | np.ndarray:
tuple:
- patch_points (torch.Tensor): Points of different patches of
shape [K, N, 3+C].
shape [K, N, 3+C].
- patch_idxs (torch.Tensor): Index of each point in
`patch_points`, of shape [K, N].
`patch_points`, of shape [K, N].
"""
device = points.device
# we assume the first three dims are points' 3D coordinates
......@@ -372,7 +373,7 @@ class EncoderDecoder3D(Base3DSegmentor):
return patch_points, patch_idxs
def slide_inference(self, point: Tensor, img_meta: List[dict],
rescale: bool):
rescale: bool) -> Tensor:
"""Inference by sliding-window with overlap.
Args:
......@@ -417,14 +418,14 @@ class EncoderDecoder3D(Base3DSegmentor):
return preds.transpose(0, 1) # to [num_classes, K*N]
def whole_inference(self, points: Tensor, input_metas: List[dict],
rescale: bool):
rescale: bool) -> Tensor:
"""Inference with full scene (one forward pass without sliding)."""
seg_logit = self.encode_decode(points, input_metas)
# TODO: if rescale and voxelization segmentor
return seg_logit
def inference(self, points: Tensor, input_metas: List[dict],
rescale: bool):
rescale: bool) -> Tensor:
"""Inference with slide/whole style.
Args:
......@@ -489,7 +490,7 @@ class EncoderDecoder3D(Base3DSegmentor):
seg_map = seg_map.cpu()
seg_pred_list.append(seg_map)
return self.postprocess_result(seg_pred_list, batch_input_metas)
return self.postprocess_result(seg_pred_list, batch_data_samples)
def _forward(self,
batch_inputs_dict: dict,
......@@ -502,7 +503,7 @@ class EncoderDecoder3D(Base3DSegmentor):
- points (list[torch.Tensor]): Point cloud of each sample.
- imgs (torch.Tensor, optional): Image tensor has shape
(B, C, H, W).
(B, C, H, W).
batch_data_samples (List[:obj:`Det3DDataSample`]): The seg
data samples. It usually includes information such
as `metainfo` and `gt_pts_sem_seg`.
......@@ -510,7 +511,8 @@ class EncoderDecoder3D(Base3DSegmentor):
Returns:
Tensor: Forward output of model without any post-processes.
"""
x = self.extract_feat(batch_inputs_dict)
points = torch.stack(batch_inputs_dict['points'])
x = self.extract_feat(points)
return self.decode_head.forward(x)
def aug_test(self, batch_inputs, batch_img_metas):
......
......@@ -5,6 +5,15 @@ short_version = __version__
def parse_version_info(version_str):
    """Parse a version string into a tuple.

    Args:
        version_str (str): The version string.

    Returns:
        tuple[int | str]: The version info, e.g., "1.3.0" is parsed into
            (1, 3, 0), and "2.0.0rc1" is parsed into (2, 0, 0, 'rc1').
    """
    version_info = []
    for x in version_str.split('.'):
        if x.isdigit():
            # Plain numeric component, e.g. the "3" in "1.3.0".
            version_info.append(int(x))
        elif x.find('rc') != -1:
            # Release-candidate component, e.g. "0rc1" -> 0, 'rc1'.
            patch_version = x.split('rc')
            version_info.append(int(patch_version[0]))
            version_info.append(f'rc{patch_version[1]}')
    return tuple(version_info)
......
mmcv-full>=2.0.0rc0,<2.1.0
mmcv>=2.0.0rc0,<2.1.0
mmdet>=3.0.0rc0,<3.1.0
mmengine>=0.1.0,<1.0.0
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment