Commit 9419a722 authored by zhangwenwei's avatar zhangwenwei
Browse files

Merge branch 'fix-api-doc' into 'master'

Update docstrings

See merge request open-mmlab/mmdet.3d!138
parents 6d189b92 5d9682a2
...@@ -145,8 +145,6 @@ class PointSegClassMapping(object): ...@@ -145,8 +145,6 @@ class PointSegClassMapping(object):
class NormalizePointsColor(object): class NormalizePointsColor(object):
"""Normalize color of points. """Normalize color of points.
Normalize color of the points.
Args: Args:
color_mean (list[float]): Mean color of the point cloud. color_mean (list[float]): Mean color of the point cloud.
""" """
......
...@@ -19,7 +19,7 @@ class MultiScaleFlipAug3D(object): ...@@ -19,7 +19,7 @@ class MultiScaleFlipAug3D(object):
flip_direction (str | list[str]): Flip augmentation directions flip_direction (str | list[str]): Flip augmentation directions
for images, options are "horizontal" and "vertical". for images, options are "horizontal" and "vertical".
If flip_direction is list, multiple flip augmentations will If flip_direction is list, multiple flip augmentations will
be applied. It has no effect when flip == False. be applied. It has no effect when ``flip == False``.
Default: "horizontal". Default: "horizontal".
pcd_horizontal_flip (bool): Whether apply horizontal flip augmentation pcd_horizontal_flip (bool): Whether apply horizontal flip augmentation
to point cloud. Default: True. Note that it works only when to point cloud. Default: True. Note that it works only when
......
...@@ -13,9 +13,8 @@ class ScanNetDataset(Custom3DDataset): ...@@ -13,9 +13,8 @@ class ScanNetDataset(Custom3DDataset):
This class serves as the API for experiments on the ScanNet Dataset. This class serves as the API for experiments on the ScanNet Dataset.
Please refer to `<https://github.com/ScanNet/ScanNet>`_ Please refer to the `github repo <https://github.com/ScanNet/ScanNet>`_
for data downloading. It is recommended to symlink the dataset root to for data downloading.
$MMDETECTION3D/data and organize them as the doc shows.
Args: Args:
data_root (str): Path of dataset root. data_root (str): Path of dataset root.
...@@ -70,10 +69,9 @@ class ScanNetDataset(Custom3DDataset): ...@@ -70,10 +69,9 @@ class ScanNetDataset(Custom3DDataset):
index (int): Index of the annotation data to get. index (int): Index of the annotation data to get.
Returns: Returns:
dict: Standard annotation dictionary dict: annotation information consists of the following keys:
consists of the data information.
- gt_bboxes_3d (:obj:`DepthInstance3DBoxes`): - gt_bboxes_3d (:obj:`DepthInstance3DBoxes`): \
3D ground truth bboxes 3D ground truth bboxes
- gt_labels_3d (np.ndarray): labels of ground truths - gt_labels_3d (np.ndarray): labels of ground truths
- pts_instance_mask_path (str): path of instance masks - pts_instance_mask_path (str): path of instance masks
......
...@@ -9,13 +9,12 @@ from .custom_3d import Custom3DDataset ...@@ -9,13 +9,12 @@ from .custom_3d import Custom3DDataset
@DATASETS.register_module() @DATASETS.register_module()
class SUNRGBDDataset(Custom3DDataset): class SUNRGBDDataset(Custom3DDataset):
"""SUNRGBD Dataset. r"""SUNRGBD Dataset.
This class serves as the API for experiments on the SUNRGBD Dataset. This class serves as the API for experiments on the SUNRGBD Dataset.
Please refer to `<http://rgbd.cs.princeton.edu/challenge.html>`_for See the `download page <http://rgbd.cs.princeton.edu/challenge.html>`_
data downloading. It is recommended to symlink the dataset root to for data downloading.
$MMDETECTION3D/data and organize them as the doc shows.
Args: Args:
data_root (str): Path of dataset root. data_root (str): Path of dataset root.
...@@ -68,10 +67,9 @@ class SUNRGBDDataset(Custom3DDataset): ...@@ -68,10 +67,9 @@ class SUNRGBDDataset(Custom3DDataset):
index (int): Index of the annotation data to get. index (int): Index of the annotation data to get.
Returns: Returns:
dict: Standard annotation dictionary dict: annotation information consists of the following keys:
consists of the data information.
- gt_bboxes_3d (:obj:``DepthInstance3DBoxes``): - gt_bboxes_3d (:obj:`DepthInstance3DBoxes`): \
3D ground truth bboxes 3D ground truth bboxes
- gt_labels_3d (np.ndarray): labels of ground truths - gt_labels_3d (np.ndarray): labels of ground truths
- pts_instance_mask_path (str): path of instance masks - pts_instance_mask_path (str): path of instance masks
......
...@@ -121,11 +121,11 @@ class PointNet2SASSG(nn.Module): ...@@ -121,11 +121,11 @@ class PointNet2SASSG(nn.Module):
Returns: Returns:
dict[str, list[torch.Tensor]]: outputs after SA and FP modules. dict[str, list[torch.Tensor]]: outputs after SA and FP modules.
- fp_xyz (list[torch.Tensor]): contains the coordinates of - fp_xyz (list[torch.Tensor]): contains the coordinates of \
each fp features. each fp features.
- fp_features (list[torch.Tensor]): contains the features - fp_features (list[torch.Tensor]): contains the features \
from each Feature Propagate Layers. from each Feature Propagate Layers.
- fp_indices (list[torch.Tensor]): contains indices of the - fp_indices (list[torch.Tensor]): contains indices of the \
input points. input points.
""" """
xyz, features = self._split_point_feats(points) xyz, features = self._split_point_feats(points)
......
...@@ -140,8 +140,8 @@ class Anchor3DHead(nn.Module, AnchorTrainMixin): ...@@ -140,8 +140,8 @@ class Anchor3DHead(nn.Module, AnchorTrainMixin):
x (torch.Tensor): Input features. x (torch.Tensor): Input features.
Returns: Returns:
tuple[torch.Tensor]: Contain score of each class, bbox predictions tuple[torch.Tensor]: Contain score of each class, bbox \
and class predictions of direction. regression and direction classification predictions.
""" """
cls_score = self.conv_cls(x) cls_score = self.conv_cls(x)
bbox_pred = self.conv_reg(x) bbox_pred = self.conv_reg(x)
...@@ -158,7 +158,7 @@ class Anchor3DHead(nn.Module, AnchorTrainMixin): ...@@ -158,7 +158,7 @@ class Anchor3DHead(nn.Module, AnchorTrainMixin):
features produced by FPN. features produced by FPN.
Returns: Returns:
tuple[list[torch.Tensor]]: Multi-level class score, bbox tuple[list[torch.Tensor]]: Multi-level class score, bbox \
and direction predictions. and direction predictions.
""" """
return multi_apply(self.forward_single, feats) return multi_apply(self.forward_single, feats)
...@@ -172,7 +172,7 @@ class Anchor3DHead(nn.Module, AnchorTrainMixin): ...@@ -172,7 +172,7 @@ class Anchor3DHead(nn.Module, AnchorTrainMixin):
device (str): device of current module device (str): device of current module
Returns: Returns:
list[list[torch.Tensor]]: anchors of each image, valid flags list[list[torch.Tensor]]: anchors of each image, valid flags \
of each image of each image
""" """
num_imgs = len(input_metas) num_imgs = len(input_metas)
...@@ -202,7 +202,7 @@ class Anchor3DHead(nn.Module, AnchorTrainMixin): ...@@ -202,7 +202,7 @@ class Anchor3DHead(nn.Module, AnchorTrainMixin):
num_total_samples (int): The number of valid samples. num_total_samples (int): The number of valid samples.
Returns: Returns:
tuple[torch.Tensor]: losses of class, bbox tuple[torch.Tensor]: losses of class, bbox \
and direction, respectively. and direction, respectively.
""" """
# classification loss # classification loss
...@@ -251,14 +251,14 @@ class Anchor3DHead(nn.Module, AnchorTrainMixin): ...@@ -251,14 +251,14 @@ class Anchor3DHead(nn.Module, AnchorTrainMixin):
"""Convert the rotation difference to difference in sine function. """Convert the rotation difference to difference in sine function.
Args: Args:
boxes1 (torch.Tensor): shape (NxC), where C>=7 and boxes1 (torch.Tensor): Original Boxes in shape (NxC), where C>=7
the 7th dimension is rotation dimension and the 7th dimension is rotation dimension.
boxes2 (torch.Tensor): shape (NxC), where C>=7 and the 7th boxes2 (torch.Tensor): Target boxes in shape (NxC), where C>=7 and
dimension is rotation dimension the 7th dimension is rotation dimension.
Returns: Returns:
tuple[torch.Tensor]: boxes1 and boxes2 whose 7th dimensions tuple[torch.Tensor]: ``boxes1`` and ``boxes2`` whose 7th \
are changed dimensions are changed.
""" """
rad_pred_encoding = torch.sin(boxes1[..., 6:7]) * torch.cos( rad_pred_encoding = torch.sin(boxes1[..., 6:7]) * torch.cos(
boxes2[..., 6:7]) boxes2[..., 6:7])
...@@ -293,12 +293,12 @@ class Anchor3DHead(nn.Module, AnchorTrainMixin): ...@@ -293,12 +293,12 @@ class Anchor3DHead(nn.Module, AnchorTrainMixin):
which bounding. which bounding.
Returns: Returns:
dict[str, list[torch.Tensor]]: Classification, bbox, and direction dict[str, list[torch.Tensor]]: Classification, bbox, and \
losses of each level. direction losses of each level.
- loss_cls (list[torch.Tensor]): Classification losses. - loss_cls (list[torch.Tensor]): Classification losses.
- loss_bbox (list[torch.Tensor]): Box regression losses. - loss_bbox (list[torch.Tensor]): Box regression losses.
- loss_dir (list[torch.Tensor]): Direction classification - loss_dir (list[torch.Tensor]): Direction classification \
losses. losses.
""" """
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
......
...@@ -104,12 +104,12 @@ class PartA2RPNHead(Anchor3DHead): ...@@ -104,12 +104,12 @@ class PartA2RPNHead(Anchor3DHead):
which bounding. which bounding.
Returns: Returns:
dict[str, list[torch.Tensor]]: Classification, bbox, and direction dict[str, list[torch.Tensor]]: Classification, bbox, and \
losses of each level. direction losses of each level.
- loss_rpn_cls (list[torch.Tensor]): Classification losses. - loss_rpn_cls (list[torch.Tensor]): Classification losses.
- loss_rpn_bbox (list[torch.Tensor]): Box regression losses. - loss_rpn_bbox (list[torch.Tensor]): Box regression losses.
- loss_rpn_dir (list[torch.Tensor]): Direction classification - loss_rpn_dir (list[torch.Tensor]): Direction classification \
losses. losses.
""" """
loss_dict = super().loss(cls_scores, bbox_preds, dir_cls_preds, loss_dict = super().loss(cls_scores, bbox_preds, dir_cls_preds,
...@@ -143,7 +143,7 @@ class PartA2RPNHead(Anchor3DHead): ...@@ -143,7 +143,7 @@ class PartA2RPNHead(Anchor3DHead):
rescale (list[torch.Tensor]): whether th rescale bbox. rescale (list[torch.Tensor]): whether th rescale bbox.
Returns: Returns:
dict: Predictions of single batch. Contain the keys: dict: Predictions of single batch containing the following keys:
- boxes_3d (:obj:`BaseInstance3DBoxes`): Predicted 3d bboxes. - boxes_3d (:obj:`BaseInstance3DBoxes`): Predicted 3d bboxes.
- scores_3d (torch.Tensor): Score of each bbox. - scores_3d (torch.Tensor): Score of each bbox.
......
...@@ -15,9 +15,7 @@ from mmdet.models import HEADS ...@@ -15,9 +15,7 @@ from mmdet.models import HEADS
@HEADS.register_module() @HEADS.register_module()
class VoteHead(nn.Module): class VoteHead(nn.Module):
"""Bbox head of Votenet. r"""Bbox head of `Votenet <https://arxiv.org/abs/1904.09664>`_.
https://arxiv.org/pdf/1904.09664.pdf
Args: Args:
num_classes (int): The number of class. num_classes (int): The number of class.
...@@ -113,11 +111,13 @@ class VoteHead(nn.Module): ...@@ -113,11 +111,13 @@ class VoteHead(nn.Module):
def forward(self, feat_dict, sample_mod): def forward(self, feat_dict, sample_mod):
"""Forward pass. """Forward pass.
The forward of VoteHead is devided into 4 steps: Note:
1. Generate vote_points from seed_points. The forward of VoteHead is devided into 4 steps:
2. Aggregate vote_points.
3. Predict bbox and score. 1. Generate vote_points from seed_points.
4. Decode predictions. 2. Aggregate vote_points.
3. Predict bbox and score.
4. Decode predictions.
Args: Args:
feat_dict (dict): feature dict from backbone. feat_dict (dict): feature dict from backbone.
......
...@@ -26,14 +26,15 @@ def chamfer_distance(src, ...@@ -26,14 +26,15 @@ def chamfer_distance(src,
The valid reduction method are 'none', 'sum' or 'mean'. The valid reduction method are 'none', 'sum' or 'mean'.
Returns: Returns:
tuple: Source and Destination loss with indices. tuple: Source and Destination loss with the corresponding indices.
- loss_src (torch.Tensor): The min distance
- loss_src (torch.Tensor): The min distance \
from source to destination. from source to destination.
- loss_dst (torch.Tensor): The min distance - loss_dst (torch.Tensor): The min distance \
from destination to source. from destination to source.
- indices1 (torch.Tensor): Index the min distance point - indices1 (torch.Tensor): Index the min distance point \
for each point in source to destination. for each point in source to destination.
- indices2 (torch.Tensor): Index the min distance point - indices2 (torch.Tensor): Index the min distance point \
for each point in destination to source. for each point in destination to source.
""" """
...@@ -123,10 +124,11 @@ class ChamferDistance(nn.Module): ...@@ -123,10 +124,11 @@ class ChamferDistance(nn.Module):
Defaults to False. Defaults to False.
Returns: Returns:
tuple[torch.Tensor]: If ``return_indices=True``, return losses of tuple[torch.Tensor]: If ``return_indices=True``, return losses of \
source and target with their corresponding indices in the order source and target with their corresponding indices in the \
of (loss_source, loss_target, indices1, indices2). If order of ``(loss_source, loss_target, indices1, indices2)``. \
``return_indices=False``, return (loss_source, loss_target). If ``return_indices=False``, return \
``(loss_source, loss_target)``.
""" """
assert reduction_override in (None, 'none', 'mean', 'sum') assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = ( reduction = (
......
...@@ -7,9 +7,7 @@ from ..registry import MIDDLE_ENCODERS ...@@ -7,9 +7,7 @@ from ..registry import MIDDLE_ENCODERS
@MIDDLE_ENCODERS.register_module() @MIDDLE_ENCODERS.register_module()
class SparseEncoder(nn.Module): class SparseEncoder(nn.Module):
"""Sparse encoder for Second. r"""Sparse encoder for SECOND and Part-A2.
See https://arxiv.org/abs/1907.03670 for more detials.
Args: Args:
in_channels (int): the number of input channels in_channels (int): the number of input channels
......
...@@ -8,9 +8,9 @@ from ..registry import MIDDLE_ENCODERS ...@@ -8,9 +8,9 @@ from ..registry import MIDDLE_ENCODERS
@MIDDLE_ENCODERS.register_module() @MIDDLE_ENCODERS.register_module()
class SparseUNet(nn.Module): class SparseUNet(nn.Module):
"""SparseUNet for PartA^2. r"""SparseUNet for PartA^2.
See https://arxiv.org/abs/1907.03670 for more detials. See the `paper <https://arxiv.org/abs/1907.03670>`_ for more detials.
Args: Args:
in_channels (int): the number of input channels in_channels (int): the number of input channels
...@@ -95,12 +95,13 @@ class SparseUNet(nn.Module): ...@@ -95,12 +95,13 @@ class SparseUNet(nn.Module):
"""Forward of SparseUNet. """Forward of SparseUNet.
Args: Args:
voxel_features (torch.float32): shape [N, C] voxel_features (torch.float32): Voxel features in shape [N, C].
coors (torch.int32): shape [N, 4](batch_idx, z_idx, y_idx, x_idx) coors (torch.int32): Coordinates in shape [N, 4],
batch_size (int): batch size the columns in the order of (batch_idx, z_idx, y_idx, x_idx).
batch_size (int): Batch size.
Returns: Returns:
dict: backbone features dict[str, torch.Tensor]: Backbone features.
""" """
coors = coors.int() coors = coors.int()
input_sp_tensor = spconv.SparseConvTensor(voxel_features, coors, input_sp_tensor = spconv.SparseConvTensor(voxel_features, coors,
...@@ -147,14 +148,14 @@ class SparseUNet(nn.Module): ...@@ -147,14 +148,14 @@ class SparseUNet(nn.Module):
"""Forward of upsample and residual block. """Forward of upsample and residual block.
Args: Args:
x_lateral (:obj:`SparseConvTensor`): lateral tensor x_lateral (:obj:`SparseConvTensor`): Lateral tensor.
x_bottom (:obj:`SparseConvTensor`): feature from bottom layer x_bottom (:obj:`SparseConvTensor`): Feature from bottom layer.
lateral_layer (SparseBasicBlock): convolution for lateral tensor lateral_layer (SparseBasicBlock): Convolution for lateral tensor.
merge_layer (SparseSequential): convolution for merging features merge_layer (SparseSequential): Convolution for merging features.
upsample_layer (SparseSequential): convolution for upsampling upsample_layer (SparseSequential): Convolution for upsampling.
Returns: Returns:
:obj:`SparseConvTensor`: upsampled feature :obj:`SparseConvTensor`: Upsampled feature.
""" """
x = lateral_layer(x_lateral) x = lateral_layer(x_lateral)
x.features = torch.cat((x_bottom.features, x.features), dim=1) x.features = torch.cat((x_bottom.features, x.features), dim=1)
...@@ -169,11 +170,12 @@ class SparseUNet(nn.Module): ...@@ -169,11 +170,12 @@ class SparseUNet(nn.Module):
"""reduce channel for element-wise addition. """reduce channel for element-wise addition.
Args: Args:
x (:obj:`SparseConvTensor`): x.features (N, C1) x (:obj:`SparseConvTensor`): Sparse tensor, ``x.features``
out_channels (int): the number of channel after reduction are in shape (N, C1).
out_channels (int): The number of channel after reduction.
Returns: Returns:
:obj:`SparseConvTensor`: channel reduced feature :obj:`SparseConvTensor`: Channel reduced feature.
""" """
features = x.features features = x.features
n, in_channels = features.shape n, in_channels = features.shape
...@@ -187,12 +189,12 @@ class SparseUNet(nn.Module): ...@@ -187,12 +189,12 @@ class SparseUNet(nn.Module):
"""make encoder layers using sparse convs. """make encoder layers using sparse convs.
Args: Args:
make_block (method): a bounded function to build blocks make_block (method): A bounded function to build blocks.
norm_cfg (dict[str]): config of normalization layer norm_cfg (dict[str]): Config of normalization layer.
in_channels (int): the number of encoder input channels in_channels (int): The number of encoder input channels.
Returns: Returns:
int: the number of encoder output channels int: the number of encoder output channels.
""" """
self.encoder_layers = spconv.SparseSequential() self.encoder_layers = spconv.SparseSequential()
...@@ -233,12 +235,12 @@ class SparseUNet(nn.Module): ...@@ -233,12 +235,12 @@ class SparseUNet(nn.Module):
"""make decoder layers using sparse convs. """make decoder layers using sparse convs.
Args: Args:
make_block (method): a bounded function to build blocks make_block (method): A bounded function to build blocks.
norm_cfg (dict[str]): config of normalization layer norm_cfg (dict[str]): Config of normalization layer.
in_channels (int): the number of encoder input channels in_channels (int): The number of encoder input channels.
Returns: Returns:
int: the number of encoder output channels int: The number of encoder output channels.
""" """
block_num = len(self.decoder_channels) block_num = len(self.decoder_channels)
for i, block_channels in enumerate(self.decoder_channels): for i, block_channels in enumerate(self.decoder_channels):
......
...@@ -23,7 +23,7 @@ class VoteModule(nn.Module): ...@@ -23,7 +23,7 @@ class VoteModule(nn.Module):
Default: dict(type='BN1d'). Default: dict(type='BN1d').
norm_feats (bool): Whether to normalize features. norm_feats (bool): Whether to normalize features.
Default: True. Default: True.
vote_loss (dict): config of vote loss. vote_loss (dict): Config of vote loss.
""" """
def __init__(self, def __init__(self,
...@@ -66,18 +66,19 @@ class VoteModule(nn.Module): ...@@ -66,18 +66,19 @@ class VoteModule(nn.Module):
"""forward. """forward.
Args: Args:
seed_points (torch.Tensor): (B, N, 3) coordinate of the seed seed_points (torch.Tensor): Coordinate of the seed
points. points in shape (B, N, 3).
seed_feats (torch.Tensor): (B, C, N) features of the seed points. seed_feats (torch.Tensor): Features of the seed points in shape
(B, C, N).
Returns: Returns:
tuple[torch.Tensor]: tuple[torch.Tensor]:
- vote_points: Voted xyz based on the seed points - vote_points: Voted xyz based on the seed points \
with shape (B, M, 3) M=num_seed*vote_per_seed. with shape (B, M, 3), ``M=num_seed*vote_per_seed``.
- vote_features: Voted features based on the seed points with - vote_features: Voted features based on the seed points with \
shape (B, C, M) where M=num_seed*vote_per_seed, shape (B, C, M) where ``M=num_seed*vote_per_seed``, \
C=vote_feature_dim. ``C=vote_feature_dim``.
""" """
batch_size, feat_channels, num_seed = seed_feats.shape batch_size, feat_channels, num_seed = seed_feats.shape
num_vote = num_seed * self.vote_per_seed num_vote = num_seed * self.vote_per_seed
...@@ -108,14 +109,14 @@ class VoteModule(nn.Module): ...@@ -108,14 +109,14 @@ class VoteModule(nn.Module):
"""Calculate loss of voting module. """Calculate loss of voting module.
Args: Args:
seed_points (torch.Tensor): coordinate of the seed points. seed_points (torch.Tensor): Coordinate of the seed points.
vote_points (torch.Tensor): coordinate of the vote points. vote_points (torch.Tensor): Coordinate of the vote points.
seed_indices (torch.Tensor): indices of seed points in raw points. seed_indices (torch.Tensor): Indices of seed points in raw points.
vote_targets_mask (torch.Tensor): mask of valid vote targets. vote_targets_mask (torch.Tensor): Mask of valid vote targets.
vote_targets (torch.Tensor): targets of votes. vote_targets (torch.Tensor): Targets of votes.
Returns: Returns:
torch.Tensor: weighted vote loss. torch.Tensor: Weighted vote loss.
""" """
batch_size, num_seed = seed_points.shape[:2] batch_size, num_seed = seed_points.shape[:2]
......
...@@ -73,10 +73,10 @@ class Base3DRoIHead(nn.Module, metaclass=ABCMeta): ...@@ -73,10 +73,10 @@ class Base3DRoIHead(nn.Module, metaclass=ABCMeta):
by 3D box structures. by 3D box structures.
gt_labels (list[torch.LongTensor]): GT labels of each sample. gt_labels (list[torch.LongTensor]): GT labels of each sample.
gt_bboxes_ignore (list[torch.Tensor], optional): gt_bboxes_ignore (list[torch.Tensor], optional):
Specify which bounding. Ground truth boxes to be ignored.
Returns: Returns:
dict: losses from each head. dict[str, torch.Tensor]: losses from each head.
""" """
pass pass
......
...@@ -87,7 +87,7 @@ class PointwiseSemanticHead(nn.Module): ...@@ -87,7 +87,7 @@ class PointwiseSemanticHead(nn.Module):
gt_labels_3d (torch.Tensor): shape [box_num], class label of gt gt_labels_3d (torch.Tensor): shape [box_num], class label of gt
Returns: Returns:
tuple[torch.Tensor]: segmentation targets with shape [voxel_num] tuple[torch.Tensor]: segmentation targets with shape [voxel_num] \
part prediction targets with shape [voxel_num, 3] part prediction targets with shape [voxel_num, 3]
""" """
gt_bboxes_3d = gt_bboxes_3d.to(voxel_centers.device) gt_bboxes_3d = gt_bboxes_3d.to(voxel_centers.device)
...@@ -136,10 +136,10 @@ class PointwiseSemanticHead(nn.Module): ...@@ -136,10 +136,10 @@ class PointwiseSemanticHead(nn.Module):
Returns: Returns:
dict: prediction targets dict: prediction targets
- seg_targets (torch.Tensor): segmentation targets - seg_targets (torch.Tensor): Segmentation targets \
with shape [voxel_num] with shape [voxel_num].
- part_targets (torch.Tensor): part prediction targets - part_targets (torch.Tensor): Part prediction targets \
with shape [voxel_num, 3] with shape [voxel_num, 3].
""" """
batch_size = len(gt_labels_3d) batch_size = len(gt_labels_3d)
voxel_center_list = [] voxel_center_list = []
......
...@@ -96,7 +96,7 @@ def _fill_trainval_infos(lyft, ...@@ -96,7 +96,7 @@ def _fill_trainval_infos(lyft,
"""Generate the train/val infos from the raw data. """Generate the train/val infos from the raw data.
Args: Args:
lyft (:obj:``LyftDataset``): Dataset class in the Lyft dataset. lyft (:obj:`LyftDataset`): Dataset class in the Lyft dataset.
train_scenes (list[str]): Basic information of training scenes. train_scenes (list[str]): Basic information of training scenes.
val_scenes (list[str]): Basic information of validation scenes. val_scenes (list[str]): Basic information of validation scenes.
test (bool): Whether use the test mode. In the test mode, no test (bool): Whether use the test mode. In the test mode, no
......
...@@ -146,7 +146,7 @@ def _fill_trainval_infos(nusc, ...@@ -146,7 +146,7 @@ def _fill_trainval_infos(nusc,
"""Generate the train/val infos from the raw data. """Generate the train/val infos from the raw data.
Args: Args:
nusc (:obj:``NuScenes``): Dataset class in the nuScenes dataset. nusc (:obj:`NuScenes`): Dataset class in the nuScenes dataset.
train_scenes (list[str]): Basic information of training scenes. train_scenes (list[str]): Basic information of training scenes.
val_scenes (list[str]): Basic information of validation scenes. val_scenes (list[str]): Basic information of validation scenes.
test (bool): Whether use the test mode. In the test mode, no test (bool): Whether use the test mode. In the test mode, no
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment