Commit 406ce50b authored by zhangwenwei's avatar zhangwenwei
Browse files

Merge branch 'docstring_tai' into 'master'

add dataset docstrings and refine core docstrings

See merge request open-mmlab/mmdet.3d!147
parents 66245956 80680274
...@@ -76,7 +76,7 @@ class Anchor3DRangeGenerator(object): ...@@ -76,7 +76,7 @@ class Anchor3DRangeGenerator(object):
@property @property
def num_levels(self): def num_levels(self):
"""int: Number of feature levels that the generator is applied.""" """int: Number of feature levels that the generator is applied to."""
return len(self.scales) return len(self.scales)
def grid_anchors(self, featmap_sizes, device='cuda'): def grid_anchors(self, featmap_sizes, device='cuda'):
...@@ -220,14 +220,14 @@ class AlignedAnchor3DRangeGenerator(Anchor3DRangeGenerator): ...@@ -220,14 +220,14 @@ class AlignedAnchor3DRangeGenerator(Anchor3DRangeGenerator):
Note: Note:
The `align` means that the anchor's center is aligned with the voxel grid, The `align` means that the anchor's center is aligned with the voxel grid,
which is also the feature grid. The previous implementation of which is also the feature grid. The previous implementation of
`Anchor3DRangeGenerator` do not generate the anchors' center according `Anchor3DRangeGenerator` does not generate the anchors' center according
to the voxel grid. Rather, it generates the center by uniformly to the voxel grid. Rather, it generates the center by uniformly
distributing the anchors inside the minimum and maximum anchor ranges distributing the anchors inside the minimum and maximum anchor ranges
according to the feature map sizes. according to the feature map sizes.
However, this makes the anchors' centers not match the feature grid. However, this makes the anchors' centers not match the feature grid.
The AlignedAnchor3DRangeGenerator adds + 1 when using the feature map sizes The AlignedAnchor3DRangeGenerator adds + 1 when using the feature map sizes
to obtain the corners of the voxel grid. Then it shifts the coordinates to to obtain the corners of the voxel grid. Then it shifts the coordinates to
the center of voxel grid of use the left up corner to distribute anchors. the center of voxel grid and use the left up corner to distribute anchors.
Args: Args:
anchor_corner (bool): Whether to align with the corner of the voxel anchor_corner (bool): Whether to align with the corner of the voxel
......
...@@ -9,7 +9,7 @@ class BboxOverlapsNearest3D(object): ...@@ -9,7 +9,7 @@ class BboxOverlapsNearest3D(object):
Note: Note:
This IoU calculator first finds the nearest 2D boxes in bird eye view This IoU calculator first finds the nearest 2D boxes in bird eye view
(BEV), and then calculate the 2D IoU using :meth:`bbox_overlaps`. (BEV), and then calculates the 2D IoU using :meth:`bbox_overlaps`.
Args: Args:
coordinate (str): 'camera', 'lidar', or 'depth' coordinate system. coordinate (str): 'camera', 'lidar', or 'depth' coordinate system.
...@@ -140,7 +140,7 @@ def bbox_overlaps_3d(bboxes1, bboxes2, mode='iou', coordinate='camera'): ...@@ -140,7 +140,7 @@ def bbox_overlaps_3d(bboxes1, bboxes2, mode='iou', coordinate='camera'):
"""Calculate 3D IoU using cuda implementation. """Calculate 3D IoU using cuda implementation.
Note: Note:
This function calculate the IoU of 3D boxes based on their volumes. This function calculates the IoU of 3D boxes based on their volumes.
IoU calculator :class:`BboxOverlaps3D` uses this function to IoU calculator :class:`BboxOverlaps3D` uses this function to
calculate the actual IoUs of boxes. calculate the actual IoUs of boxes.
......
...@@ -15,7 +15,7 @@ class BaseInstance3DBoxes(object): ...@@ -15,7 +15,7 @@ class BaseInstance3DBoxes(object):
Args: Args:
tensor (torch.Tensor | np.ndarray | list): a N x box_dim matrix. tensor (torch.Tensor | np.ndarray | list): a N x box_dim matrix.
box_dim (int): Number of the dimension of a box box_dim (int): Number of the dimension of a box.
Each row is (x, y, z, x_size, y_size, z_size, yaw). Each row is (x, y, z, x_size, y_size, z_size, yaw).
Default to 7. Default to 7.
with_yaw (bool): Whether the box is with yaw rotation. with_yaw (bool): Whether the box is with yaw rotation.
...@@ -79,7 +79,7 @@ class BaseInstance3DBoxes(object): ...@@ -79,7 +79,7 @@ class BaseInstance3DBoxes(object):
@property @property
def yaw(self): def yaw(self):
"""Obtain the rotation of all the boxes. """Obtain the rotations of all the boxes.
Returns: Returns:
torch.Tensor: A vector with yaw of each box. torch.Tensor: A vector with yaw of each box.
...@@ -118,7 +118,7 @@ class BaseInstance3DBoxes(object): ...@@ -118,7 +118,7 @@ class BaseInstance3DBoxes(object):
"""Calculate the center of all the boxes. """Calculate the center of all the boxes.
Note: Note:
In the MMDetection.3D's convention, the bottom center is In the MMDetection3D's convention, the bottom center is
usually taken as the default center. usually taken as the default center.
The relative position of the centers in different kinds of The relative position of the centers in different kinds of
...@@ -161,11 +161,11 @@ class BaseInstance3DBoxes(object): ...@@ -161,11 +161,11 @@ class BaseInstance3DBoxes(object):
@abstractmethod @abstractmethod
def rotate(self, angles, axis=0): def rotate(self, angles, axis=0):
"""Calculate whether the points is in any of the boxes. """Calculate whether the points are in any of the boxes.
Args: Args:
angles (float): Rotation angles angles (float): Rotation angles.
axis (int): The axis to rotate the boxes axis (int): The axis to rotate the boxes.
""" """
pass pass
...@@ -175,7 +175,7 @@ class BaseInstance3DBoxes(object): ...@@ -175,7 +175,7 @@ class BaseInstance3DBoxes(object):
pass pass
def translate(self, trans_vector): def translate(self, trans_vector):
"""Calculate whether the points is in any of the boxes. """Calculate whether the points are in any of the boxes.
Args: Args:
trans_vector (torch.Tensor): Translation vector of size 1x3. trans_vector (torch.Tensor): Translation vector of size 1x3.
...@@ -188,7 +188,7 @@ class BaseInstance3DBoxes(object): ...@@ -188,7 +188,7 @@ class BaseInstance3DBoxes(object):
"""Check whether the boxes are in the given range. """Check whether the boxes are in the given range.
Args: Args:
box_range (list | torch.Tensor): the range of box box_range (list | torch.Tensor): The range of box
(x_min, y_min, z_min, x_max, y_max, z_max) (x_min, y_min, z_min, x_max, y_max, z_max)
Note: Note:
...@@ -217,7 +217,7 @@ class BaseInstance3DBoxes(object): ...@@ -217,7 +217,7 @@ class BaseInstance3DBoxes(object):
in order of (x_min, y_min, x_max, y_max). in order of (x_min, y_min, x_max, y_max).
Returns: Returns:
torch.Tensor: Indicating whether each box is inside torch.Tensor: Indicating whether each box is inside \
the reference range. the reference range.
""" """
pass pass
...@@ -227,7 +227,7 @@ class BaseInstance3DBoxes(object): ...@@ -227,7 +227,7 @@ class BaseInstance3DBoxes(object):
"""Convert self to `dst` mode. """Convert self to `dst` mode.
Args: Args:
dst (:obj:`BoxMode`): The target Box mode dst (:obj:`BoxMode`): The target Box mode.
rt_mat (np.ndarray | torch.Tensor): The rotation and translation rt_mat (np.ndarray | torch.Tensor): The rotation and translation
matrix between different coordinates. Defaults to None. matrix between different coordinates. Defaults to None.
The conversion from `src` coordinates to `dst` coordinates The conversion from `src` coordinates to `dst` coordinates
...@@ -317,7 +317,7 @@ class BaseInstance3DBoxes(object): ...@@ -317,7 +317,7 @@ class BaseInstance3DBoxes(object):
@classmethod @classmethod
def cat(cls, boxes_list): def cat(cls, boxes_list):
"""Concatenates a list of Boxes into a single Boxes. """Concatenate a list of Boxes into a single Boxes.
Args: Args:
boxes_list (list[:obj:`BaseInstance3DBoxes`]): List of boxes. boxes_list (list[:obj:`BaseInstance3DBoxes`]): List of boxes.
...@@ -345,7 +345,7 @@ class BaseInstance3DBoxes(object): ...@@ -345,7 +345,7 @@ class BaseInstance3DBoxes(object):
device (str | :obj:`torch.device`): The name of the device. device (str | :obj:`torch.device`): The name of the device.
Returns: Returns:
:obj:`BaseInstance3DBoxes`: A new boxes object in the :obj:`BaseInstance3DBoxes`: A new boxes object on the \
specific device. specific device.
""" """
original_type = type(self) original_type = type(self)
...@@ -367,7 +367,7 @@ class BaseInstance3DBoxes(object): ...@@ -367,7 +367,7 @@ class BaseInstance3DBoxes(object):
@property @property
def device(self): def device(self):
"""str: The device of the boxes are in.""" """str: The device of the boxes are on."""
return self.tensor.device return self.tensor.device
def __iter__(self): def __iter__(self):
...@@ -383,7 +383,7 @@ class BaseInstance3DBoxes(object): ...@@ -383,7 +383,7 @@ class BaseInstance3DBoxes(object):
"""Calculate height overlaps of two boxes. """Calculate height overlaps of two boxes.
Note: Note:
This function calculate the height overlaps between boxes1 and This function calculates the height overlaps between boxes1 and
boxes2, boxes1 and boxes2 should be in the same type. boxes2, boxes1 and boxes2 should be in the same type.
Args: Args:
...@@ -415,8 +415,8 @@ class BaseInstance3DBoxes(object): ...@@ -415,8 +415,8 @@ class BaseInstance3DBoxes(object):
"""Calculate 3D overlaps of two boxes. """Calculate 3D overlaps of two boxes.
Note: Note:
This function calculate the overlaps between boxes1 and boxes2, This function calculates the overlaps between boxes1 and boxes2,
boxes1 and boxes2 are not necessarily to be in the same type. boxes1 and boxes2 are not necessarily in the same type.
Args: Args:
boxes1 (:obj:`BaseInstance3DBoxes`): Boxes 1 contain N boxes. boxes1 (:obj:`BaseInstance3DBoxes`): Boxes 1 contain N boxes.
...@@ -470,12 +470,11 @@ class BaseInstance3DBoxes(object): ...@@ -470,12 +470,11 @@ class BaseInstance3DBoxes(object):
def new_box(self, data): def new_box(self, data):
"""Create a new box object with data. """Create a new box object with data.
The new box and its tensor has the similar properties The new box and its tensor has the similar properties \
as self and self.tensor, respectively. as self and self.tensor, respectively.
Args: Args:
data (torch.Tensor | numpy.array | list): Data which the data (torch.Tensor | numpy.array | list): Data to be copied.
returned Tensor copies.
Returns: Returns:
:obj:`BaseInstance3DBoxes`: A new bbox with data and other \ :obj:`BaseInstance3DBoxes`: A new bbox with data and other \
......
...@@ -67,8 +67,8 @@ class Box3DMode(IntEnum): ...@@ -67,8 +67,8 @@ class Box3DMode(IntEnum):
box (tuple | list | np.ndarray | box (tuple | list | np.ndarray |
torch.Tensor | BaseInstance3DBoxes): torch.Tensor | BaseInstance3DBoxes):
Can be a k-tuple, k-list or an Nxk array/tensor, where k = 7. Can be a k-tuple, k-list or an Nxk array/tensor, where k = 7.
src (BoxMode): The src Box mode. src (:obj:`BoxMode`): The src Box mode.
dst (BoxMode): The target Box mode. dst (:obj:`BoxMode`): The target Box mode.
rt_mat (np.ndarray | torch.Tensor): The rotation and translation rt_mat (np.ndarray | torch.Tensor): The rotation and translation
matrix between different coordinates. Defaults to None. matrix between different coordinates. Defaults to None.
The conversion from `src` coordinates to `dst` coordinates The conversion from `src` coordinates to `dst` coordinates
......
...@@ -36,7 +36,7 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes): ...@@ -36,7 +36,7 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes):
@property @property
def height(self): def height(self):
"""Obtain the height of all the boxes. """Obtain the heights of all the boxes.
Returns: Returns:
torch.Tensor: A vector with height of each box. torch.Tensor: A vector with height of each box.
...@@ -125,7 +125,7 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes): ...@@ -125,7 +125,7 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes):
"""Calculate the 2D bounding boxes in BEV with rotation. """Calculate the 2D bounding boxes in BEV with rotation.
Returns: Returns:
torch.Tensor: A n x 5 tensor of 2D BEV box of each box. torch.Tensor: A n x 5 tensor of 2D BEV box of each box. \
The box is in XYWHR format. The box is in XYWHR format.
""" """
return self.tensor[:, [0, 2, 3, 5, 6]] return self.tensor[:, [0, 2, 3, 5, 6]]
...@@ -231,7 +231,7 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes): ...@@ -231,7 +231,7 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes):
polygon, we try to reduce the burden for simpler cases. polygon, we try to reduce the burden for simpler cases.
Returns: Returns:
torch.Tensor: Indicating whether each box is inside torch.Tensor: Indicating whether each box is inside \
the reference range. the reference range.
""" """
in_range_flags = ((self.tensor[:, 0] > box_range[0]) in_range_flags = ((self.tensor[:, 0] > box_range[0])
...@@ -245,8 +245,8 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes): ...@@ -245,8 +245,8 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes):
"""Calculate height overlaps of two boxes. """Calculate height overlaps of two boxes.
Note: Note:
This function calculate the height overlaps between boxes1 and This function calculates the height overlaps between boxes1 and
boxes2, boxes1 and boxes2 should be in the same type. boxes2, where boxes1 and boxes2 should be in the same type.
Args: Args:
boxes1 (:obj:`BaseInstance3DBoxes`): Boxes 1 contain N boxes. boxes1 (:obj:`BaseInstance3DBoxes`): Boxes 1 contain N boxes.
......
...@@ -93,7 +93,7 @@ class DepthInstance3DBoxes(BaseInstance3DBoxes): ...@@ -93,7 +93,7 @@ class DepthInstance3DBoxes(BaseInstance3DBoxes):
"""Calculate the 2D bounding boxes in BEV with rotation. """Calculate the 2D bounding boxes in BEV with rotation.
Returns: Returns:
torch.Tensor: A n x 5 tensor of 2D BEV box of each box. torch.Tensor: A n x 5 tensor of 2D BEV box of each box. \
The box is in XYWHR format. The box is in XYWHR format.
""" """
return self.tensor[:, [0, 1, 3, 4, 6]] return self.tensor[:, [0, 1, 3, 4, 6]]
...@@ -241,11 +241,12 @@ class DepthInstance3DBoxes(BaseInstance3DBoxes): ...@@ -241,11 +241,12 @@ class DepthInstance3DBoxes(BaseInstance3DBoxes):
"""Find points that are in boxes (CUDA). """Find points that are in boxes (CUDA).
Args: Args:
points (torch.Tensor): [1, M, 3] or [M, 3], [x, y, z] points (torch.Tensor): Points in shape [1, M, 3] or [M, 3], \
in LiDAR coordinate. 3 dimensions are [x, y, z] in LiDAR coordinate.
Returns: Returns:
torch.Tensor: The box index of each point in, shape is (B, M, T). torch.Tensor: The index of boxes each point lies in with shape \
of (B, M, T).
""" """
from .box_3d_mode import Box3DMode from .box_3d_mode import Box3DMode
......
...@@ -26,7 +26,7 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes): ...@@ -26,7 +26,7 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes):
Attributes: Attributes:
tensor (torch.Tensor): Float matrix of N x box_dim. tensor (torch.Tensor): Float matrix of N x box_dim.
box_dim (int): Integer indicates the dimension of a box box_dim (int): Integer indicating the dimension of a box.
Each row is (x, y, z, x_size, y_size, z_size, yaw, ...). Each row is (x, y, z, x_size, y_size, z_size, yaw, ...).
with_yaw (bool): If True, the value of yaw will be set to 0 as minmax with_yaw (bool): If True, the value of yaw will be set to 0 as minmax
boxes. boxes.
...@@ -93,7 +93,7 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes): ...@@ -93,7 +93,7 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes):
"""Calculate the 2D bounding boxes in BEV with rotation. """Calculate the 2D bounding boxes in BEV with rotation.
Returns: Returns:
torch.Tensor: A nx5 tensor of 2D BEV box of each box. \ torch.Tensor: A n x 5 tensor of 2D BEV box of each box. \
The box is in XYWHR format. The box is in XYWHR format.
""" """
return self.tensor[:, [0, 1, 3, 4, 6]] return self.tensor[:, [0, 1, 3, 4, 6]]
...@@ -201,11 +201,9 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes): ...@@ -201,11 +201,9 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes):
In the original implementation of SECOND, checking whether In the original implementation of SECOND, checking whether
a box in the range checks whether the points are in a convex a box in the range checks whether the points are in a convex
polygon, we try to reduce the burden for simpler cases. polygon, we try to reduce the burden for simpler cases.
TODO: check whether this will effect the performance
Returns: Returns:
torch.Tensor: Indicating whether each box is inside \ torch.Tensor: Whether each box is inside the reference range.
the reference range.
""" """
in_range_flags = ((self.tensor[:, 0] > box_range[0]) in_range_flags = ((self.tensor[:, 0] > box_range[0])
& (self.tensor[:, 1] > box_range[1]) & (self.tensor[:, 1] > box_range[1])
...@@ -236,7 +234,7 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes): ...@@ -236,7 +234,7 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes):
"""Enlarge the length, width and height boxes. """Enlarge the length, width and height boxes.
Args: Args:
extra_width (float | torch.Tensor): extra width to enlarge the box extra_width (float | torch.Tensor): Extra width to enlarge the box.
Returns: Returns:
:obj:`LiDARInstance3DBoxes`: Enlarged boxes. :obj:`LiDARInstance3DBoxes`: Enlarged boxes.
...@@ -251,7 +249,7 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes): ...@@ -251,7 +249,7 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes):
"""Find the box which the points are in. """Find the box which the points are in.
Args: Args:
points (torch.Tensor): Points in shape Nx3. points (torch.Tensor): Points in shape (N, 3).
Returns: Returns:
torch.Tensor: The index of the box that each point is in. torch.Tensor: The index of the box that each point is in.
......
...@@ -86,7 +86,7 @@ def get_box_type(box_type): ...@@ -86,7 +86,7 @@ def get_box_type(box_type):
"""Get the type and mode of box structure. """Get the type and mode of box structure.
Args: Args:
box_type (str): Indicate the type of box structure. box_type (str): The type of box structure.
The valid value are "LiDAR", "Camera", or "Depth". The valid value are "LiDAR", "Camera", or "Depth".
Returns: Returns:
......
...@@ -24,14 +24,15 @@ def bbox3d_mapping_back(bboxes, scale_factor, flip_horizontal, flip_vertical): ...@@ -24,14 +24,15 @@ def bbox3d_mapping_back(bboxes, scale_factor, flip_horizontal, flip_vertical):
def bbox3d2roi(bbox_list): def bbox3d2roi(bbox_list):
"""Convert a list of bboxes to roi format. """Convert a list of bounding boxes to roi format.
Args: Args:
bbox_list (list[torch.Tensor]): A list of bboxes bbox_list (list[torch.Tensor]): A list of bounding boxes
corresponding to a batch of images. corresponding to a batch of images.
Returns: Returns:
torch.Tensor: shape (n, c), [batch_ind, x, y ...]. torch.Tensor: Region of interests in shape (n, c), where \
the channels are in order of [batch_ind, x, y ...].
""" """
rois_list = [] rois_list = []
for img_id, bboxes in enumerate(bbox_list): for img_id, bboxes in enumerate(bbox_list):
...@@ -54,11 +55,11 @@ def bbox3d2result(bboxes, scores, labels): ...@@ -54,11 +55,11 @@ def bbox3d2result(bboxes, scores, labels):
scores (torch.Tensor): Scores with shape of (n, ). scores (torch.Tensor): Scores with shape of (n, ).
Returns: Returns:
dict[str, torch.Tensor]: Bbox results in cpu mode. dict[str, torch.Tensor]: Bounding box results in cpu mode.
- boxes_3d (torch.Tensor): 3D boxes - boxes_3d (torch.Tensor): 3D boxes.
- scores (torch.Tensor): prediction scores - scores (torch.Tensor): Prediction scores.
- labels_3d (torch.Tensor): box labels - labels_3d (torch.Tensor): Box labels.
""" """
return dict( return dict(
boxes_3d=bboxes.to('cpu'), boxes_3d=bboxes.to('cpu'),
......
...@@ -212,14 +212,16 @@ def indoor_eval(gt_annos, ...@@ -212,14 +212,16 @@ def indoor_eval(gt_annos,
Evaluate the result of the detection. Evaluate the result of the detection.
Args: Args:
gt_annos (list[dict]): GT annotations. gt_annos (list[dict]): Ground truth annotations.
dt_annos (list[dict]): Detection annotations. The dict dt_annos (list[dict]): Detection annotations. The dict
includes the following keys includes the following keys
- labels_3d (torch.Tensor): Labels of boxes. - labels_3d (torch.Tensor): Labels of boxes.
- boxes_3d (BaseInstance3DBoxes): 3D bboxes in Depth coordinate. - boxes_3d (:obj:`BaseInstance3DBoxes`): \
3D bounding boxes in Depth coordinate.
- scores_3d (torch.Tensor): Scores of boxes. - scores_3d (torch.Tensor): Scores of boxes.
metric (list[float]): AP IoU thresholds. metric (list[float]): IoU thresholds for computing average precisions.
label2cat (dict): {label: cat}. label2cat (dict): Map from label to category.
logger (logging.Logger | str | None): The way to print the mAP logger (logging.Logger | str | None): The way to print the mAP
summary. See `mmdet.utils.print_log()` for details. Default: None. summary. See `mmdet.utils.print_log()` for details. Default: None.
......
...@@ -16,18 +16,18 @@ def box3d_multiclass_nms(mlvl_bboxes, ...@@ -16,18 +16,18 @@ def box3d_multiclass_nms(mlvl_bboxes,
mlvl_bboxes (torch.Tensor): Multi-level boxes with shape (N, M). mlvl_bboxes (torch.Tensor): Multi-level boxes with shape (N, M).
M is the dimensions of boxes. M is the dimensions of boxes.
mlvl_bboxes_for_nms (torch.Tensor): Multi-level boxes with shape mlvl_bboxes_for_nms (torch.Tensor): Multi-level boxes with shape
(N, 4), N is the number of boxes. (N, 4). N is the number of boxes.
mlvl_scores (torch.Tensor): Multi-level boxes with shape mlvl_scores (torch.Tensor): Multi-level boxes with shape
(N, ), N is the number of boxes. (N, ). N is the number of boxes.
score_thr (float): Score threshold to filter boxes with low score_thr (float): Score threshold to filter boxes with low
confidence. confidence.
max_num (int): Maximum number of boxes that will be kept. max_num (int): Maximum number of boxes that will be kept.
cfg (dict): Config dict of NMS. cfg (dict): Configuration dict of NMS.
mlvl_dir_scores (torch.Tensor, optional): Multi-level scores mlvl_dir_scores (torch.Tensor, optional): Multi-level scores
of direction classifier. Defaults to None. of direction classifier. Defaults to None.
Returns: Returns:
tuple[torch.Tensor]: Return results after nms, including 3D tuple[torch.Tensor]: Return results after nms, including 3D \
bounding boxes, scores, labels and direction scores. bounding boxes, scores, labels and direction scores.
""" """
# do multi class nms # do multi class nms
......
...@@ -10,6 +10,7 @@ def merge_aug_bboxes_3d(aug_results, img_metas, test_cfg): ...@@ -10,6 +10,7 @@ def merge_aug_bboxes_3d(aug_results, img_metas, test_cfg):
Args: Args:
aug_results (list[dict]): The dict of detection results. aug_results (list[dict]): The dict of detection results.
The dict contains the following keys The dict contains the following keys
- boxes_3d (:obj:`BaseInstance3DBoxes`): Detection bbox. - boxes_3d (:obj:`BaseInstance3DBoxes`): Detection bbox.
- scores_3d (torch.Tensor): Detection scores. - scores_3d (torch.Tensor): Detection scores.
- labels_3d (torch.Tensor): Predicted box labels. - labels_3d (torch.Tensor): Predicted box labels.
......
...@@ -265,13 +265,28 @@ class Custom3DDataset(Dataset): ...@@ -265,13 +265,28 @@ class Custom3DDataset(Dataset):
return ret_dict return ret_dict
def __len__(self): def __len__(self):
"""Return the length of data infos.
Returns:
int: Length of data infos.
"""
return len(self.data_infos) return len(self.data_infos)
def _rand_another(self, idx): def _rand_another(self, idx):
"""Randomly get another item with the same flag.
Returns:
int: Another index of item with the same flag.
"""
pool = np.where(self.flag == self.flag[idx])[0] pool = np.where(self.flag == self.flag[idx])[0]
return np.random.choice(pool) return np.random.choice(pool)
def __getitem__(self, idx): def __getitem__(self, idx):
"""Get item from infos according to the given index.
Returns:
dict: Data dictionary of the corresponding index.
"""
if self.test_mode: if self.test_mode:
return self.prepare_test_data(idx) return self.prepare_test_data(idx)
while True: while True:
...@@ -285,7 +300,7 @@ class Custom3DDataset(Dataset): ...@@ -285,7 +300,7 @@ class Custom3DDataset(Dataset):
"""Set flag according to image aspect ratio. """Set flag according to image aspect ratio.
Images with aspect ratio greater than 1 will be set as group 1, Images with aspect ratio greater than 1 will be set as group 1,
otherwise group 0. In 3D datasets, they are all the same, thus are all otherwise group 0. In 3D datasets, they are all the same, thus
zeros. are all zeros.
""" """
self.flag = np.zeros(len(self), dtype=np.uint8) self.flag = np.zeros(len(self), dtype=np.uint8)
...@@ -6,6 +6,33 @@ from mmdet.datasets import DATASETS, CustomDataset ...@@ -6,6 +6,33 @@ from mmdet.datasets import DATASETS, CustomDataset
@DATASETS.register_module() @DATASETS.register_module()
class Kitti2DDataset(CustomDataset): class Kitti2DDataset(CustomDataset):
r"""KITTI 2D Dataset.
This class serves as the API for experiments on the `KITTI Dataset
<http://www.cvlibs.net/datasets/kitti/eval_object.php?obj_benchmark=3d>`_.
Args:
data_root (str): Path of dataset root.
ann_file (str): Path of annotation file.
pipeline (list[dict], optional): Pipeline used for data processing.
Defaults to None.
classes (tuple[str], optional): Classes used in the dataset.
Defaults to None.
modality (dict, optional): Modality to specify the sensor data used
as input. Defaults to None.
box_type_3d (str, optional): Type of 3D box of this dataset.
Based on the `box_type_3d`, the dataset will encapsulate the box
to its original format then convert them to `box_type_3d`.
Defaults to 'LiDAR'. Available options include
- 'LiDAR': Box in LiDAR coordinates.
- 'Depth': Box in depth coordinates, usually for indoor dataset.
- 'Camera': Box in camera coordinates.
filter_empty_gt (bool, optional): Whether to filter empty GT.
Defaults to True.
test_mode (bool, optional): Whether the dataset is in test mode.
Defaults to False.
"""
CLASSES = ('car', 'pedestrian', 'cyclist') CLASSES = ('car', 'pedestrian', 'cyclist')
""" """
...@@ -173,6 +200,16 @@ class Kitti2DDataset(CustomDataset): ...@@ -173,6 +200,16 @@ class Kitti2DDataset(CustomDataset):
return inds return inds
def reformat_bbox(self, outputs, out=None): def reformat_bbox(self, outputs, out=None):
"""Reformat bounding boxes to KITTI 2D styles.
Args:
outputs (list[np.ndarray]): List of arrays storing the inferenced
bounding boxes and scores.
out (str | None): The prefix of output file. Default: None.
Returns:
list[dict]: A list of dictionaries with the kitti 2D format.
"""
from mmdet3d.core.bbox.transforms import bbox2result_kitti2d from mmdet3d.core.bbox.transforms import bbox2result_kitti2d
sample_idx = [info['image']['image_idx'] for info in self.data_infos] sample_idx = [info['image']['image_idx'] for info in self.data_infos]
result_files = bbox2result_kitti2d(outputs, self.CLASSES, sample_idx, result_files = bbox2result_kitti2d(outputs, self.CLASSES, sample_idx,
......
...@@ -74,6 +74,14 @@ class KittiDataset(Custom3DDataset): ...@@ -74,6 +74,14 @@ class KittiDataset(Custom3DDataset):
self.pts_prefix = pts_prefix self.pts_prefix = pts_prefix
def _get_pts_filename(self, idx): def _get_pts_filename(self, idx):
"""Get point cloud filename according to the given index.
Args:
idx (int): Index of the point cloud file to get.
Returns:
str: Name of the point cloud file.
"""
pts_filename = osp.join(self.root_split, self.pts_prefix, pts_filename = osp.join(self.root_split, self.pts_prefix,
f'{idx:06d}.bin') f'{idx:06d}.bin')
return pts_filename return pts_filename
...@@ -351,6 +359,19 @@ class KittiDataset(Custom3DDataset): ...@@ -351,6 +359,19 @@ class KittiDataset(Custom3DDataset):
class_names, class_names,
pklfile_prefix=None, pklfile_prefix=None,
submission_prefix=None): submission_prefix=None):
"""Convert 3D detection results to kitti format for evaluation and test
submission.
Args:
net_outputs (list[np.ndarray]): List of array storing the \
inferenced bounding boxes and scores.
class_names (list[String]): A list of class names.
pklfile_prefix (str | None): The prefix of pkl file.
submission_prefix (str | None): The prefix of submission file.
Returns:
list[dict]: A list of dictionaries with the kitti format.
"""
assert len(net_outputs) == len(self.data_infos) assert len(net_outputs) == len(self.data_infos)
if submission_prefix is not None: if submission_prefix is not None:
mmcv.mkdir_or_exist(submission_prefix) mmcv.mkdir_or_exist(submission_prefix)
...@@ -457,14 +478,14 @@ class KittiDataset(Custom3DDataset): ...@@ -457,14 +478,14 @@ class KittiDataset(Custom3DDataset):
submission. submission.
Args: Args:
net_outputs (list[np.ndarray]): list of array storing the net_outputs (list[np.ndarray]): List of array storing the \
bbox and score inferenced bounding boxes and scores.
class_nanes (list[String]): A list of class names class_names (list[String]): A list of class names.
pklfile_prefix (str | None): The prefix of pkl file. pklfile_prefix (str | None): The prefix of pkl file.
submission_prefix (str | None): The prefix of submission file. submission_prefix (str | None): The prefix of submission file.
Returns: Returns:
list[dict]: A list of dict have the kitti format list[dict]: A list of dictionaries have the kitti format
""" """
assert len(net_outputs) == len(self.data_infos) assert len(net_outputs) == len(self.data_infos)
...@@ -561,6 +582,28 @@ class KittiDataset(Custom3DDataset): ...@@ -561,6 +582,28 @@ class KittiDataset(Custom3DDataset):
return det_annos return det_annos
def convert_valid_bboxes(self, box_dict, info): def convert_valid_bboxes(self, box_dict, info):
"""Convert the predicted boxes into valid ones.
Args:
box_dict (dict): Box dictionaries to be converted.
- boxes_3d (:obj:`LiDARInstance3DBoxes`): 3D bounding boxes.
- scores_3d (torch.Tensor): Scores of boxes.
- labels_3d (torch.Tensor): Class labels of boxes.
info (dict): Data info.
Returns:
dict: Valid predicted boxes.
- bbox (np.ndarray): 2D bounding boxes.
- box3d_camera (np.ndarray): 3D bounding boxes in \
camera coordinate.
- box3d_lidar (np.ndarray): 3D bounding boxes in \
LiDAR coordinate.
- scores (np.ndarray): Scores of boxes.
- label_preds (np.ndarray): Class label predictions.
- sample_idx (int): Sample index.
"""
# TODO: refactor this function # TODO: refactor this function
box_preds = box_dict['boxes_3d'] box_preds = box_dict['boxes_3d']
scores = box_dict['scores_3d'] scores = box_dict['scores_3d']
......
...@@ -30,6 +30,15 @@ class DefaultFormatBundle(object): ...@@ -30,6 +30,15 @@ class DefaultFormatBundle(object):
return return
def __call__(self, results): def __call__(self, results):
"""Call function to transform and format common fields in results.
Args:
results (dict): Result dict contains the data to convert.
Returns:
dict: The result dict contains the data that is formatted with
default bundle.
"""
if 'img' in results: if 'img' in results:
if isinstance(results['img'], list): if isinstance(results['img'], list):
# process multiple imgs in single frame # process multiple imgs in single frame
...@@ -132,6 +141,17 @@ class Collect3D(object): ...@@ -132,6 +141,17 @@ class Collect3D(object):
self.meta_keys = meta_keys self.meta_keys = meta_keys
def __call__(self, results): def __call__(self, results):
"""Call function to collect keys in results. The keys in ``meta_keys``
will be converted to :obj:`mmcv.DataContainer`.
Args:
results (dict): Result dict contains the data to collect.
Returns:
dict: The result dict contains the following keys
- keys in ``self.keys``
- ``img_metas``
"""
data = {} data = {}
img_metas = {} img_metas = {}
for key in self.meta_keys: for key in self.meta_keys:
...@@ -144,6 +164,7 @@ class Collect3D(object): ...@@ -144,6 +164,7 @@ class Collect3D(object):
return data return data
def __repr__(self): def __repr__(self):
"""str: Return a string that describes the module."""
return self.__class__.__name__ + '(keys={}, meta_keys={})'.format( return self.__class__.__name__ + '(keys={}, meta_keys={})'.format(
self.keys, self.meta_keys) self.keys, self.meta_keys)
...@@ -171,6 +192,15 @@ class DefaultFormatBundle3D(DefaultFormatBundle): ...@@ -171,6 +192,15 @@ class DefaultFormatBundle3D(DefaultFormatBundle):
self.with_label = with_label self.with_label = with_label
def __call__(self, results): def __call__(self, results):
"""Call function to transform and format common fields in results.
Args:
results (dict): Result dict containing the data to convert.
Returns:
dict: The result dict contains the data that is formatted with
default bundle.
"""
# Format 3D data # Format 3D data
for key in [ for key in [
'voxels', 'coors', 'voxel_centers', 'num_points', 'points' 'voxels', 'coors', 'voxel_centers', 'num_points', 'points'
...@@ -220,6 +250,7 @@ class DefaultFormatBundle3D(DefaultFormatBundle): ...@@ -220,6 +250,7 @@ class DefaultFormatBundle3D(DefaultFormatBundle):
return results return results
def __repr__(self): def __repr__(self):
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__ repr_str = self.__class__.__name__
repr_str += '(class_names={}, '.format(self.class_names) repr_str += '(class_names={}, '.format(self.class_names)
repr_str += 'with_gt={}, with_label={})'.format( repr_str += 'with_gt={}, with_label={})'.format(
......
...@@ -17,6 +17,23 @@ class LoadMultiViewImageFromFiles(object): ...@@ -17,6 +17,23 @@ class LoadMultiViewImageFromFiles(object):
self.color_type = color_type self.color_type = color_type
def __call__(self, results): def __call__(self, results):
"""Call function to load multi-view image from files.
Args:
results (dict): Result dict containing multi-view image filenames.
Returns:
dict: The result dict containing the multi-view image data. \
Added keys and values are described below.
- filename (str): Multi-view image filenames.
- img (np.ndarray): Multi-view image arrays.
- img_shape (tuple[int]): Shape of multi-view image arrays.
- ori_shape (tuple[int]): Shape of original image arrays.
- pad_shape (tuple[int]): Shape of padded image arrays.
- scale_factor (float): Scale factor.
- img_norm_cfg (dict): Normalization configuration of images.
"""
filename = results['img_filename'] filename = results['img_filename']
img = np.stack( img = np.stack(
[mmcv.imread(name, self.color_type) for name in filename], axis=-1) [mmcv.imread(name, self.color_type) for name in filename], axis=-1)
...@@ -37,6 +54,7 @@ class LoadMultiViewImageFromFiles(object): ...@@ -37,6 +54,7 @@ class LoadMultiViewImageFromFiles(object):
return results return results
def __repr__(self): def __repr__(self):
"""str: Return a string that describes the module."""
return "{} (to_float32={}, color_type='{}')".format( return "{} (to_float32={}, color_type='{}')".format(
self.__class__.__name__, self.to_float32, self.color_type) self.__class__.__name__, self.to_float32, self.color_type)
...@@ -65,6 +83,14 @@ class LoadPointsFromMultiSweeps(object): ...@@ -65,6 +83,14 @@ class LoadPointsFromMultiSweeps(object):
self.file_client = None self.file_client = None
def _load_points(self, pts_filename): def _load_points(self, pts_filename):
"""Private function to load point clouds data.
Args:
pts_filename (str): Filename of point clouds data.
Returns:
np.ndarray: An array containing point clouds data.
"""
if self.file_client is None: if self.file_client is None:
self.file_client = mmcv.FileClient(**self.file_client_args) self.file_client = mmcv.FileClient(**self.file_client_args)
try: try:
...@@ -79,6 +105,18 @@ class LoadPointsFromMultiSweeps(object): ...@@ -79,6 +105,18 @@ class LoadPointsFromMultiSweeps(object):
return points return points
def __call__(self, results): def __call__(self, results):
"""Call function to load multi-sweep point clouds from files.
Args:
results (dict): Result dict containing multi-sweep point cloud \
filenames.
Returns:
dict: The result dict containing the multi-sweep points data. \
Added key and value are described below.
- points (np.ndarray): Multi-sweep point cloud arrays.
"""
points = results['points'] points = results['points']
points[:, 3] /= 255 points[:, 3] /= 255
points[:, 4] = 0 points[:, 4] = 0
...@@ -103,6 +141,7 @@ class LoadPointsFromMultiSweeps(object): ...@@ -103,6 +141,7 @@ class LoadPointsFromMultiSweeps(object):
return results return results
def __repr__(self): def __repr__(self):
"""str: Return a string that describes the module."""
return f'{self.__class__.__name__}(sweeps_num={self.sweeps_num})' return f'{self.__class__.__name__}(sweeps_num={self.sweeps_num})'
...@@ -121,6 +160,17 @@ class PointSegClassMapping(object): ...@@ -121,6 +160,17 @@ class PointSegClassMapping(object):
self.valid_cat_ids = valid_cat_ids self.valid_cat_ids = valid_cat_ids
def __call__(self, results): def __call__(self, results):
"""Call function to map original semantic class to valid category ids.
Args:
results (dict): Result dict containing point semantic masks.
Returns:
dict: The result dict containing the mapped category ids. \
Updated key and value are described below.
- pts_semantic_mask (np.ndarray): Mapped semantic masks.
"""
assert 'pts_semantic_mask' in results assert 'pts_semantic_mask' in results
pts_semantic_mask = results['pts_semantic_mask'] pts_semantic_mask = results['pts_semantic_mask']
neg_cls = len(self.valid_cat_ids) neg_cls = len(self.valid_cat_ids)
...@@ -136,6 +186,7 @@ class PointSegClassMapping(object): ...@@ -136,6 +186,7 @@ class PointSegClassMapping(object):
return results return results
def __repr__(self): def __repr__(self):
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__ repr_str = self.__class__.__name__
repr_str += '(valid_cat_ids={})'.format(self.valid_cat_ids) repr_str += '(valid_cat_ids={})'.format(self.valid_cat_ids)
return repr_str return repr_str
...@@ -153,6 +204,17 @@ class NormalizePointsColor(object): ...@@ -153,6 +204,17 @@ class NormalizePointsColor(object):
self.color_mean = color_mean self.color_mean = color_mean
def __call__(self, results): def __call__(self, results):
"""Call function to normalize color of points.
Args:
results (dict): Result dict containing point clouds data.
Returns:
dict: The result dict containing the normalized points. \
Updated key and value are described below.
- points (np.ndarray): Points after color normalization.
"""
points = results['points'] points = results['points']
assert points.shape[1] >= 6,\ assert points.shape[1] >= 6,\
f'Expect points have channel >=6, got {points.shape[1]}' f'Expect points have channel >=6, got {points.shape[1]}'
...@@ -161,6 +223,7 @@ class NormalizePointsColor(object): ...@@ -161,6 +223,7 @@ class NormalizePointsColor(object):
return results return results
def __repr__(self): def __repr__(self):
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__ repr_str = self.__class__.__name__
repr_str += '(color_mean={})'.format(self.color_mean) repr_str += '(color_mean={})'.format(self.color_mean)
return repr_str return repr_str
...@@ -201,6 +264,14 @@ class LoadPointsFromFile(object): ...@@ -201,6 +264,14 @@ class LoadPointsFromFile(object):
self.file_client = None self.file_client = None
def _load_points(self, pts_filename): def _load_points(self, pts_filename):
"""Private function to load point clouds data.
Args:
pts_filename (str): Filename of point clouds data.
Returns:
np.ndarray: An array containing point clouds data.
"""
if self.file_client is None: if self.file_client is None:
self.file_client = mmcv.FileClient(**self.file_client_args) self.file_client = mmcv.FileClient(**self.file_client_args)
try: try:
...@@ -215,6 +286,17 @@ class LoadPointsFromFile(object): ...@@ -215,6 +286,17 @@ class LoadPointsFromFile(object):
return points return points
def __call__(self, results): def __call__(self, results):
"""Call function to load points data from file.
Args:
results (dict): Result dict containing point clouds data.
Returns:
dict: The result dict containing the point clouds data. \
Added key and value are described below.
- points (np.ndarray): Point clouds data.
"""
pts_filename = results['pts_filename'] pts_filename = results['pts_filename']
points = self._load_points(pts_filename) points = self._load_points(pts_filename)
points = points.reshape(-1, self.load_dim) points = points.reshape(-1, self.load_dim)
...@@ -228,6 +310,7 @@ class LoadPointsFromFile(object): ...@@ -228,6 +310,7 @@ class LoadPointsFromFile(object):
return results return results
def __repr__(self): def __repr__(self):
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__ + '(' repr_str = self.__class__.__name__ + '('
repr_str += 'shift_height={}, '.format(self.shift_height) repr_str += 'shift_height={}, '.format(self.shift_height)
repr_str += 'file_client_args={}), '.format(self.file_client_args) repr_str += 'file_client_args={}), '.format(self.file_client_args)
...@@ -291,15 +374,39 @@ class LoadAnnotations3D(LoadAnnotations): ...@@ -291,15 +374,39 @@ class LoadAnnotations3D(LoadAnnotations):
self.with_seg_3d = with_seg_3d self.with_seg_3d = with_seg_3d
def _load_bboxes_3d(self, results): def _load_bboxes_3d(self, results):
"""Private function to load 3D bounding box annotations.
Args:
results (dict): Result dict from :obj:`mmdet3d.CustomDataset`.
Returns:
dict: The dict containing loaded 3D bounding box annotations.
"""
results['gt_bboxes_3d'] = results['ann_info']['gt_bboxes_3d'] results['gt_bboxes_3d'] = results['ann_info']['gt_bboxes_3d']
results['bbox3d_fields'].append('gt_bboxes_3d') results['bbox3d_fields'].append('gt_bboxes_3d')
return results return results
def _load_labels_3d(self, results): def _load_labels_3d(self, results):
"""Private function to load label annotations.
Args:
results (dict): Result dict from :obj:`mmdet3d.CustomDataset`.
Returns:
dict: The dict containing loaded label annotations.
"""
results['gt_labels_3d'] = results['ann_info']['gt_labels_3d'] results['gt_labels_3d'] = results['ann_info']['gt_labels_3d']
return results return results
def _load_masks_3d(self, results): def _load_masks_3d(self, results):
"""Private function to load 3D mask annotations.
Args:
results (dict): Result dict from :obj:`mmdet3d.CustomDataset`.
Returns:
dict: The dict containing loaded 3D mask annotations.
"""
pts_instance_mask_path = results['ann_info']['pts_instance_mask_path'] pts_instance_mask_path = results['ann_info']['pts_instance_mask_path']
if self.file_client is None: if self.file_client is None:
...@@ -317,6 +424,14 @@ class LoadAnnotations3D(LoadAnnotations): ...@@ -317,6 +424,14 @@ class LoadAnnotations3D(LoadAnnotations):
return results return results
def _load_semantic_seg_3d(self, results): def _load_semantic_seg_3d(self, results):
"""Private function to load 3D semantic segmentation annotations.
Args:
results (dict): Result dict from :obj:`mmdet3d.CustomDataset`.
Returns:
dict: The dict containing the semantic segmentation annotations.
"""
pts_semantic_mask_path = results['ann_info']['pts_semantic_mask_path'] pts_semantic_mask_path = results['ann_info']['pts_semantic_mask_path']
if self.file_client is None: if self.file_client is None:
...@@ -335,6 +450,15 @@ class LoadAnnotations3D(LoadAnnotations): ...@@ -335,6 +450,15 @@ class LoadAnnotations3D(LoadAnnotations):
return results return results
def __call__(self, results): def __call__(self, results):
"""Call function to load multiple types of annotations.
Args:
results (dict): Result dict from :obj:`mmdet3d.CustomDataset`.
Returns:
dict: The dict containing loaded 3D bounding box, label, mask and
semantic segmentation annotations.
"""
results = super().__call__(results) results = super().__call__(results)
if self.with_bbox_3d: if self.with_bbox_3d:
results = self._load_bboxes_3d(results) results = self._load_bboxes_3d(results)
...@@ -350,6 +474,7 @@ class LoadAnnotations3D(LoadAnnotations): ...@@ -350,6 +474,7 @@ class LoadAnnotations3D(LoadAnnotations):
return results return results
def __repr__(self): def __repr__(self):
"""str: Return a string that describes the module."""
indent_str = ' ' indent_str = ' '
repr_str = self.__class__.__name__ + '(\n' repr_str = self.__class__.__name__ + '(\n'
repr_str += f'{indent_str}with_bbox_3d={self.with_bbox_3d}, ' repr_str += f'{indent_str}with_bbox_3d={self.with_bbox_3d}, '
......
...@@ -63,6 +63,15 @@ class MultiScaleFlipAug3D(object): ...@@ -63,6 +63,15 @@ class MultiScaleFlipAug3D(object):
'flip has no effect when RandomFlip is not in transforms') 'flip has no effect when RandomFlip is not in transforms')
def __call__(self, results): def __call__(self, results):
"""Call function to augment common fields in results.
Args:
results (dict): Result dict containing the data to augment.
Returns:
dict: The result dict contains the data that is augmented with \
different scales and flips.
"""
aug_data = [] aug_data = []
flip_aug = [False, True] if self.flip else [False] flip_aug = [False, True] if self.flip else [False]
pcd_horizontal_flip_aug = [False, True] \ pcd_horizontal_flip_aug = [False, True] \
...@@ -97,6 +106,7 @@ class MultiScaleFlipAug3D(object): ...@@ -97,6 +106,7 @@ class MultiScaleFlipAug3D(object):
return aug_data_dict return aug_data_dict
def __repr__(self): def __repr__(self):
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__ repr_str = self.__class__.__name__
repr_str += f'(transforms={self.transforms}, ' repr_str += f'(transforms={self.transforms}, '
repr_str += f'img_scale={self.img_scale}, flip={self.flip}, ' repr_str += f'img_scale={self.img_scale}, flip={self.flip}, '
......
...@@ -46,6 +46,16 @@ class RandomFlip3D(RandomFlip): ...@@ -46,6 +46,16 @@ class RandomFlip3D(RandomFlip):
(int, float)) and 0 <= flip_ratio_bev_vertical <= 1 (int, float)) and 0 <= flip_ratio_bev_vertical <= 1
def random_flip_data_3d(self, input_dict, direction='horizontal'): def random_flip_data_3d(self, input_dict, direction='horizontal'):
"""Flip 3D data randomly.
Args:
input_dict (dict): Result dict from loading pipeline.
direction (str): Flip direction. Default: horizontal.
Returns:
dict: Flipped results, 'points', 'bbox3d_fields' keys are \
updated in the result dict.
"""
assert direction in ['horizontal', 'vertical'] assert direction in ['horizontal', 'vertical']
if len(input_dict['bbox3d_fields']) == 0: # test mode if len(input_dict['bbox3d_fields']) == 0: # test mode
input_dict['bbox3d_fields'].append('empty_box3d') input_dict['bbox3d_fields'].append('empty_box3d')
...@@ -57,6 +67,17 @@ class RandomFlip3D(RandomFlip): ...@@ -57,6 +67,17 @@ class RandomFlip3D(RandomFlip):
direction, points=input_dict['points']) direction, points=input_dict['points'])
def __call__(self, input_dict): def __call__(self, input_dict):
"""Call function to flip points, values in the ``bbox3d_fields`` and \
also flip 2D image and its annotations.
Args:
input_dict (dict): Result dict from loading pipeline.
Returns:
dict: Flipped results, 'flip', 'flip_direction', \
'pcd_horizontal_flip' and 'pcd_vertical_flip' keys are added \
into result dict.
"""
# flip 2D image and its annotations # flip 2D image and its annotations
super(RandomFlip3D, self).__call__(input_dict) super(RandomFlip3D, self).__call__(input_dict)
...@@ -80,6 +101,7 @@ class RandomFlip3D(RandomFlip): ...@@ -80,6 +101,7 @@ class RandomFlip3D(RandomFlip):
return input_dict return input_dict
def __repr__(self): def __repr__(self):
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__ repr_str = self.__class__.__name__
repr_str += '(sync_2d={},'.format(self.sync_2d) repr_str += '(sync_2d={},'.format(self.sync_2d)
repr_str += '(flip_ratio_bev_horizontal={},'.format( repr_str += '(flip_ratio_bev_horizontal={},'.format(
...@@ -108,11 +130,30 @@ class ObjectSample(object): ...@@ -108,11 +130,30 @@ class ObjectSample(object):
@staticmethod @staticmethod
def remove_points_in_boxes(points, boxes): def remove_points_in_boxes(points, boxes):
"""Remove the points in the sampled bounding boxes.
Args:
points (np.ndarray): Input point cloud array.
boxes (np.ndarray): Sampled ground truth boxes.
Returns:
np.ndarray: Points with those in the boxes removed.
"""
masks = box_np_ops.points_in_rbbox(points, boxes) masks = box_np_ops.points_in_rbbox(points, boxes)
points = points[np.logical_not(masks.any(-1))] points = points[np.logical_not(masks.any(-1))]
return points return points
def __call__(self, input_dict): def __call__(self, input_dict):
"""Call function to sample ground truth objects to the data.
Args:
input_dict (dict): Result dict from loading pipeline.
Returns:
dict: Results after object sampling augmentation, \
'points', 'gt_bboxes_3d', 'gt_labels_3d' keys are updated \
in the result dict.
"""
gt_bboxes_3d = input_dict['gt_bboxes_3d'] gt_bboxes_3d = input_dict['gt_bboxes_3d']
gt_labels_3d = input_dict['gt_labels_3d'] gt_labels_3d = input_dict['gt_labels_3d']
...@@ -163,6 +204,7 @@ class ObjectSample(object): ...@@ -163,6 +204,7 @@ class ObjectSample(object):
return input_dict return input_dict
def __repr__(self): def __repr__(self):
"""str: Return a string that describes the module."""
return self.__class__.__name__ return self.__class__.__name__
...@@ -193,6 +235,15 @@ class ObjectNoise(object): ...@@ -193,6 +235,15 @@ class ObjectNoise(object):
self.num_try = num_try self.num_try = num_try
def __call__(self, input_dict): def __call__(self, input_dict):
"""Call function to apply noise to each ground truth in the scene.
Args:
input_dict (dict): Result dict from loading pipeline.
Returns:
dict: Results after adding noise to each object, \
'points', 'gt_bboxes_3d' keys are updated in the result dict.
"""
gt_bboxes_3d = input_dict['gt_bboxes_3d'] gt_bboxes_3d = input_dict['gt_bboxes_3d']
points = input_dict['points'] points = input_dict['points']
...@@ -211,6 +262,7 @@ class ObjectNoise(object): ...@@ -211,6 +262,7 @@ class ObjectNoise(object):
return input_dict return input_dict
def __repr__(self): def __repr__(self):
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__ repr_str = self.__class__.__name__
repr_str += '(num_try={},'.format(self.num_try) repr_str += '(num_try={},'.format(self.num_try)
repr_str += ' translation_std={},'.format(self.translation_std) repr_str += ' translation_std={},'.format(self.translation_std)
...@@ -247,6 +299,16 @@ class GlobalRotScaleTrans(object): ...@@ -247,6 +299,16 @@ class GlobalRotScaleTrans(object):
self.shift_height = shift_height self.shift_height = shift_height
def _trans_bbox_points(self, input_dict): def _trans_bbox_points(self, input_dict):
"""Private function to translate bounding boxes and points.
Args:
input_dict (dict): Result dict from loading pipeline.
Returns:
dict: Results after translation, 'points', 'pcd_trans' \
and keys in input_dict['bbox3d_fields'] are updated \
in the result dict.
"""
if not isinstance(self.translation_std, (list, tuple, np.ndarray)): if not isinstance(self.translation_std, (list, tuple, np.ndarray)):
translation_std = [ translation_std = [
self.translation_std, self.translation_std, self.translation_std, self.translation_std,
...@@ -263,6 +325,16 @@ class GlobalRotScaleTrans(object): ...@@ -263,6 +325,16 @@ class GlobalRotScaleTrans(object):
input_dict[key].translate(trans_factor) input_dict[key].translate(trans_factor)
def _rot_bbox_points(self, input_dict): def _rot_bbox_points(self, input_dict):
"""Private function to rotate bounding boxes and points.
Args:
input_dict (dict): Result dict from loading pipeline.
Returns:
dict: Results after rotation, 'points', 'pcd_rotation' \
and keys in input_dict['bbox3d_fields'] are updated \
in the result dict.
"""
rotation = self.rot_range rotation = self.rot_range
if not isinstance(rotation, list): if not isinstance(rotation, list):
rotation = [-rotation, rotation] rotation = [-rotation, rotation]
...@@ -276,6 +348,15 @@ class GlobalRotScaleTrans(object): ...@@ -276,6 +348,15 @@ class GlobalRotScaleTrans(object):
input_dict['pcd_rotation'] = rot_mat_T input_dict['pcd_rotation'] = rot_mat_T
def _scale_bbox_points(self, input_dict): def _scale_bbox_points(self, input_dict):
"""Private function to scale bounding boxes and points.
Args:
input_dict (dict): Result dict from loading pipeline.
Returns:
dict: Results after scaling, 'points' and keys in \
input_dict['bbox3d_fields'] are updated in the result dict.
"""
scale = input_dict['pcd_scale_factor'] scale = input_dict['pcd_scale_factor']
input_dict['points'][:, :3] *= scale input_dict['points'][:, :3] *= scale
if self.shift_height: if self.shift_height:
...@@ -285,11 +366,31 @@ class GlobalRotScaleTrans(object): ...@@ -285,11 +366,31 @@ class GlobalRotScaleTrans(object):
input_dict[key].scale(scale) input_dict[key].scale(scale)
def _random_scale(self, input_dict): def _random_scale(self, input_dict):
"""Private function to randomly set the scale factor.
Args:
input_dict (dict): Result dict from loading pipeline.
Returns:
dict: Results after scaling, 'pcd_scale_factor' is updated \
in the result dict.
"""
scale_factor = np.random.uniform(self.scale_ratio_range[0], scale_factor = np.random.uniform(self.scale_ratio_range[0],
self.scale_ratio_range[1]) self.scale_ratio_range[1])
input_dict['pcd_scale_factor'] = scale_factor input_dict['pcd_scale_factor'] = scale_factor
def __call__(self, input_dict): def __call__(self, input_dict):
"""Call function to rotate, scale and translate bounding boxes and \
points.
Args:
input_dict (dict): Result dict from loading pipeline.
Returns:
dict: Results after transformation, 'points', 'pcd_rotation', \
'pcd_scale_factor', 'pcd_trans' and keys in \
input_dict['bbox3d_fields'] are updated in the result dict.
"""
self._rot_bbox_points(input_dict) self._rot_bbox_points(input_dict)
if 'pcd_scale_factor' not in input_dict: if 'pcd_scale_factor' not in input_dict:
...@@ -300,6 +401,7 @@ class GlobalRotScaleTrans(object): ...@@ -300,6 +401,7 @@ class GlobalRotScaleTrans(object):
return input_dict return input_dict
def __repr__(self): def __repr__(self):
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__ repr_str = self.__class__.__name__
repr_str += '(rot_range={},'.format(self.rot_range) repr_str += '(rot_range={},'.format(self.rot_range)
repr_str += ' scale_ratio_range={},'.format(self.scale_ratio_range) repr_str += ' scale_ratio_range={},'.format(self.scale_ratio_range)
...@@ -310,8 +412,18 @@ class GlobalRotScaleTrans(object): ...@@ -310,8 +412,18 @@ class GlobalRotScaleTrans(object):
@PIPELINES.register_module() @PIPELINES.register_module()
class PointShuffle(object): class PointShuffle(object):
"""Shuffle input points."""
def __call__(self, input_dict): def __call__(self, input_dict):
"""Call function to shuffle points.
Args:
input_dict (dict): Result dict from loading pipeline.
Returns:
dict: Results after shuffling, 'points' key is updated \
in the result dict.
"""
np.random.shuffle(input_dict['points']) np.random.shuffle(input_dict['points'])
return input_dict return input_dict
...@@ -321,12 +433,26 @@ class PointShuffle(object): ...@@ -321,12 +433,26 @@ class PointShuffle(object):
@PIPELINES.register_module() @PIPELINES.register_module()
class ObjectRangeFilter(object): class ObjectRangeFilter(object):
"""Filter objects by the range.
Args:
point_cloud_range (list[float]): Point cloud range.
"""
def __init__(self, point_cloud_range): def __init__(self, point_cloud_range):
self.pcd_range = np.array(point_cloud_range, dtype=np.float32) self.pcd_range = np.array(point_cloud_range, dtype=np.float32)
self.bev_range = self.pcd_range[[0, 1, 3, 4]] self.bev_range = self.pcd_range[[0, 1, 3, 4]]
def __call__(self, input_dict): def __call__(self, input_dict):
"""Call function to filter objects by the range.
Args:
input_dict (dict): Result dict from loading pipeline.
Returns:
dict: Results after filtering, 'gt_bboxes_3d', 'gt_labels_3d' \
keys are updated in the result dict.
"""
gt_bboxes_3d = input_dict['gt_bboxes_3d'] gt_bboxes_3d = input_dict['gt_bboxes_3d']
gt_labels_3d = input_dict['gt_labels_3d'] gt_labels_3d = input_dict['gt_labels_3d']
mask = gt_bboxes_3d.in_range_bev(self.bev_range) mask = gt_bboxes_3d.in_range_bev(self.bev_range)
...@@ -345,6 +471,7 @@ class ObjectRangeFilter(object): ...@@ -345,6 +471,7 @@ class ObjectRangeFilter(object):
return input_dict return input_dict
def __repr__(self): def __repr__(self):
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__ repr_str = self.__class__.__name__
repr_str += '(point_cloud_range={})'.format(self.pcd_range.tolist()) repr_str += '(point_cloud_range={})'.format(self.pcd_range.tolist())
return repr_str return repr_str
...@@ -352,12 +479,26 @@ class ObjectRangeFilter(object): ...@@ -352,12 +479,26 @@ class ObjectRangeFilter(object):
@PIPELINES.register_module() @PIPELINES.register_module()
class PointsRangeFilter(object): class PointsRangeFilter(object):
"""Filter points by the range.
Args:
point_cloud_range (list[float]): Point cloud range.
"""
def __init__(self, point_cloud_range): def __init__(self, point_cloud_range):
self.pcd_range = np.array( self.pcd_range = np.array(
point_cloud_range, dtype=np.float32)[np.newaxis, :] point_cloud_range, dtype=np.float32)[np.newaxis, :]
def __call__(self, input_dict): def __call__(self, input_dict):
"""Call function to filter points by the range.
Args:
input_dict (dict): Result dict from loading pipeline.
Returns:
dict: Results after filtering, 'points' key is updated \
in the result dict.
"""
points = input_dict['points'] points = input_dict['points']
points_mask = ((points[:, :3] >= self.pcd_range[:, :3]) points_mask = ((points[:, :3] >= self.pcd_range[:, :3])
& (points[:, :3] < self.pcd_range[:, 3:])) & (points[:, :3] < self.pcd_range[:, 3:]))
...@@ -367,6 +508,7 @@ class PointsRangeFilter(object): ...@@ -367,6 +508,7 @@ class PointsRangeFilter(object):
return input_dict return input_dict
def __repr__(self): def __repr__(self):
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__ repr_str = self.__class__.__name__
repr_str += '(point_cloud_range={})'.format(self.pcd_range.tolist()) repr_str += '(point_cloud_range={})'.format(self.pcd_range.tolist())
return repr_str return repr_str
...@@ -385,6 +527,15 @@ class ObjectNameFilter(object): ...@@ -385,6 +527,15 @@ class ObjectNameFilter(object):
self.labels = list(range(len(self.classes))) self.labels = list(range(len(self.classes)))
def __call__(self, input_dict): def __call__(self, input_dict):
"""Call function to filter objects by their names.
Args:
input_dict (dict): Result dict from loading pipeline.
Returns:
dict: Results after filtering, 'gt_bboxes_3d', 'gt_labels_3d' \
keys are updated in the result dict.
"""
gt_labels_3d = input_dict['gt_labels_3d'] gt_labels_3d = input_dict['gt_labels_3d']
gt_bboxes_mask = np.array([n in self.labels for n in gt_labels_3d], gt_bboxes_mask = np.array([n in self.labels for n in gt_labels_3d],
dtype=np.bool_) dtype=np.bool_)
...@@ -394,6 +545,7 @@ class ObjectNameFilter(object): ...@@ -394,6 +545,7 @@ class ObjectNameFilter(object):
return input_dict return input_dict
def __repr__(self): def __repr__(self):
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__ repr_str = self.__class__.__name__
repr_str += f'(classes={self.classes})' repr_str += f'(classes={self.classes})'
return repr_str return repr_str
...@@ -444,6 +596,15 @@ class IndoorPointSample(object): ...@@ -444,6 +596,15 @@ class IndoorPointSample(object):
return points[choices] return points[choices]
def __call__(self, results): def __call__(self, results):
"""Call function to sample points in indoor scenes.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Results after sampling, 'points', 'pts_instance_mask' \
and 'pts_semantic_mask' keys are updated in the result dict.
"""
points = results['points'] points = results['points']
points, choices = self.points_random_sampling( points, choices = self.points_random_sampling(
points, self.num_points, return_choices=True) points, self.num_points, return_choices=True)
...@@ -460,6 +621,7 @@ class IndoorPointSample(object): ...@@ -460,6 +621,7 @@ class IndoorPointSample(object):
return results return results
def __repr__(self): def __repr__(self):
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__ repr_str = self.__class__.__name__
repr_str += '(num_points={})'.format(self.num_points) repr_str += '(num_points={})'.format(self.num_points)
return repr_str return repr_str
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment