Commit 32f3955c authored by zhangwenwei's avatar zhangwenwei
Browse files

Merge branch 'update_head_docstrings' into 'master'

update head docstrings

See merge request open-mmlab/mmdet.3d!119
parents ea779429 d891d5c0
...@@ -169,7 +169,8 @@ class Anchor3DHead(nn.Module, AnchorTrainMixin): ...@@ -169,7 +169,8 @@ class Anchor3DHead(nn.Module, AnchorTrainMixin):
device (str): device of current module device (str): device of current module
Returns: Returns:
tuple: anchors of each image, valid flags of each image list[list[torch.Tensor]]: anchors of each image, valid flags
of each image
""" """
num_imgs = len(input_metas) num_imgs = len(input_metas)
# since feature map sizes of all images are the same, we only compute # since feature map sizes of all images are the same, we only compute
...@@ -253,7 +254,8 @@ class Anchor3DHead(nn.Module, AnchorTrainMixin): ...@@ -253,7 +254,8 @@ class Anchor3DHead(nn.Module, AnchorTrainMixin):
dimension is rotation dimension dimension is rotation dimension
Returns: Returns:
tuple: (boxes1, boxes2) whose 7th dimensions are changed tuple[torch.Tensor]: boxes1 and boxes2 whose 7th dimensions
are changed
""" """
rad_pred_encoding = torch.sin(boxes1[..., 6:7]) * torch.cos( rad_pred_encoding = torch.sin(boxes1[..., 6:7]) * torch.cos(
boxes2[..., 6:7]) boxes2[..., 6:7])
...@@ -289,6 +291,10 @@ class Anchor3DHead(nn.Module, AnchorTrainMixin): ...@@ -289,6 +291,10 @@ class Anchor3DHead(nn.Module, AnchorTrainMixin):
Returns: Returns:
dict: Contain class, bbox and direction losses of each level. dict: Contain class, bbox and direction losses of each level.
- loss_cls (list[torch.Tensor]): class losses
- loss_bbox (list[torch.Tensor]): bbox losses
- loss_dir (list[torch.Tensor]): direction losses
""" """
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
assert len(featmap_sizes) == self.anchor_generator.num_levels assert len(featmap_sizes) == self.anchor_generator.num_levels
...@@ -404,6 +410,7 @@ class Anchor3DHead(nn.Module, AnchorTrainMixin): ...@@ -404,6 +410,7 @@ class Anchor3DHead(nn.Module, AnchorTrainMixin):
Returns: Returns:
tuple: Contain predictions of single batch. tuple: Contain predictions of single batch.
- bboxes (:obj:`BaseInstance3DBoxes`): Predicted 3d bboxes. - bboxes (:obj:`BaseInstance3DBoxes`): Predicted 3d bboxes.
- scores (torch.Tensor): Class score of each bbox. - scores (torch.Tensor): Class score of each bbox.
- labels (torch.Tensor): Label of each bbox. - labels (torch.Tensor): Label of each bbox.
......
...@@ -63,6 +63,9 @@ class FreeAnchor3DHead(Anchor3DHead): ...@@ -63,6 +63,9 @@ class FreeAnchor3DHead(Anchor3DHead):
Returns: Returns:
dict: Loss items. dict: Loss items.
- positive_bag_loss (torch.Tensor): Loss of positive samples.
- negative_bag_loss (torch.Tensor): Loss of negative samples.
""" """
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
assert len(featmap_sizes) == self.anchor_generator.num_levels assert len(featmap_sizes) == self.anchor_generator.num_levels
......
...@@ -121,6 +121,7 @@ class PartA2RPNHead(Anchor3DHead): ...@@ -121,6 +121,7 @@ class PartA2RPNHead(Anchor3DHead):
Returns: Returns:
dict: Predictions of single batch. Contain the keys: dict: Predictions of single batch. Contain the keys:
- boxes_3d (:obj:`BaseInstance3DBoxes`): Predicted 3d bboxes. - boxes_3d (:obj:`BaseInstance3DBoxes`): Predicted 3d bboxes.
- scores_3d (torch.Tensor): Score of each bbox. - scores_3d (torch.Tensor): Score of each bbox.
- labels_3d (torch.Tensor): Label of each bbox. - labels_3d (torch.Tensor): Label of each bbox.
...@@ -217,6 +218,7 @@ class PartA2RPNHead(Anchor3DHead): ...@@ -217,6 +218,7 @@ class PartA2RPNHead(Anchor3DHead):
Returns: Returns:
dict: Predictions of single batch. Contain the keys: dict: Predictions of single batch. Contain the keys:
- boxes_3d (:obj:`BaseInstance3DBoxes`): Predicted 3d bboxes. - boxes_3d (:obj:`BaseInstance3DBoxes`): Predicted 3d bboxes.
- scores_3d (torch.Tensor): Score of each bbox. - scores_3d (torch.Tensor): Score of each bbox.
- labels_3d (torch.Tensor): Label of each bbox. - labels_3d (torch.Tensor): Label of each bbox.
......
...@@ -30,7 +30,11 @@ class AnchorTrainMixin(object): ...@@ -30,7 +30,11 @@ class AnchorTrainMixin(object):
sampling (bool): Whether to sample anchors. sampling (bool): Whether to sample anchors.
Returns: Returns:
tuple: Anchor targets. tuple (list, list, list, list, list, list, int, int):
Anchor targets, including labels, label weights,
bbox targets, bbox weights, direction targets,
direction weights, number of positive anchors and
number of negative anchors.
""" """
num_imgs = len(input_metas) num_imgs = len(input_metas)
assert len(anchor_list) == num_imgs assert len(anchor_list) == num_imgs
...@@ -105,7 +109,7 @@ class AnchorTrainMixin(object): ...@@ -105,7 +109,7 @@ class AnchorTrainMixin(object):
sampling (bool): Whether to sample anchors. sampling (bool): Whether to sample anchors.
Returns: Returns:
tuple: Anchor targets. tuple[torch.Tensor]: Anchor targets.
""" """
if isinstance(self.bbox_assigner, list): if isinstance(self.bbox_assigner, list):
feat_size = anchors.size(0) * anchors.size(1) * anchors.size(2) feat_size = anchors.size(0) * anchors.size(1) * anchors.size(2)
...@@ -194,7 +198,7 @@ class AnchorTrainMixin(object): ...@@ -194,7 +198,7 @@ class AnchorTrainMixin(object):
sampling (bool): Whether to sample anchors. sampling (bool): Whether to sample anchors.
Returns: Returns:
tuple: Anchor targets. tuple[torch.Tensor]: Anchor targets.
""" """
anchors = anchors.reshape(-1, anchors.size(-1)) anchors = anchors.reshape(-1, anchors.size(-1))
num_valid_anchors = anchors.shape[0] num_valid_anchors = anchors.shape[0]
......
...@@ -305,7 +305,7 @@ class VoteHead(nn.Module): ...@@ -305,7 +305,7 @@ class VoteHead(nn.Module):
bbox_preds (torch.Tensor): Bbox predictions of vote head. bbox_preds (torch.Tensor): Bbox predictions of vote head.
Returns: Returns:
tuple: Targets of vote head. tuple[torch.Tensor]: Targets of vote head.
""" """
# find empty example # find empty example
valid_gt_masks = list() valid_gt_masks = list()
...@@ -391,7 +391,7 @@ class VoteHead(nn.Module): ...@@ -391,7 +391,7 @@ class VoteHead(nn.Module):
vote aggregation layer. vote aggregation layer.
Returns: Returns:
tuple: Targets of vote head. tuple[torch.Tensor]: Targets of vote head.
""" """
assert self.bbox_coder.with_rot or pts_semantic_mask is not None assert self.bbox_coder.with_rot or pts_semantic_mask is not None
......
...@@ -71,8 +71,9 @@ class Base3DRoIHead(nn.Module, metaclass=ABCMeta): ...@@ -71,8 +71,9 @@ class Base3DRoIHead(nn.Module, metaclass=ABCMeta):
gt_bboxes (list[:obj:`BaseInstance3DBoxes`]): gt_bboxes (list[:obj:`BaseInstance3DBoxes`]):
GT bboxes of each sample. The bboxes are encapsulated GT bboxes of each sample. The bboxes are encapsulated
by 3D box structures. by 3D box structures.
gt_labels (list[LongTensor]): GT labels of each sample. gt_labels (list[torch.LongTensor]): GT labels of each sample.
gt_bboxes_ignore (list[Tensor], optional): Specify which bounding. gt_bboxes_ignore (list[torch.Tensor], optional):
Specify which bounding boxes to ignore.
Returns: Returns:
dict: losses from each head. dict: losses from each head.
......
...@@ -297,6 +297,10 @@ class PartA2BboxHead(nn.Module): ...@@ -297,6 +297,10 @@ class PartA2BboxHead(nn.Module):
Returns: Returns:
dict: Computed losses. dict: Computed losses.
- loss_cls (torch.Tensor): loss of classes.
- loss_bbox (torch.Tensor): loss of bboxes.
- loss_corner (torch.Tensor): loss of corners.
""" """
losses = dict() losses = dict()
rcnn_batch_size = cls_score.shape[0] rcnn_batch_size = cls_score.shape[0]
...@@ -359,7 +363,7 @@ class PartA2BboxHead(nn.Module): ...@@ -359,7 +363,7 @@ class PartA2BboxHead(nn.Module):
concat (bool): Whether to concatenate targets between batches. concat (bool): Whether to concatenate targets between batches.
Returns: Returns:
tuple: Targets of boxes and class prediction. tuple[torch.Tensor]: Targets of boxes and class prediction.
""" """
pos_bboxes_list = [res.pos_bboxes for res in sampling_results] pos_bboxes_list = [res.pos_bboxes for res in sampling_results]
pos_gt_bboxes_list = [res.pos_gt_bboxes for res in sampling_results] pos_gt_bboxes_list = [res.pos_gt_bboxes for res in sampling_results]
...@@ -402,7 +406,7 @@ class PartA2BboxHead(nn.Module): ...@@ -402,7 +406,7 @@ class PartA2BboxHead(nn.Module):
cfg (dict): Training configs. cfg (dict): Training configs.
Returns: Returns:
tuple: Target for positive boxes. tuple[torch.Tensor]: Target for positive boxes.
(label, bbox_targets, pos_gt_bboxes, reg_mask, label_weights, (label, bbox_targets, pos_gt_bboxes, reg_mask, label_weights,
bbox_weights) bbox_weights)
""" """
...@@ -459,11 +463,11 @@ class PartA2BboxHead(nn.Module): ...@@ -459,11 +463,11 @@ class PartA2BboxHead(nn.Module):
"""Calculate corner loss of given boxes. """Calculate corner loss of given boxes.
Args: Args:
pred_bbox3d (FloatTensor): predicted boxes with shape (N, 7). pred_bbox3d (torch.FloatTensor): predicted boxes with shape (N, 7).
gt_bbox3d (FloatTensor): gt boxes with shape (N, 7). gt_bbox3d (torch.FloatTensor): gt boxes with shape (N, 7).
Returns: Returns:
FloatTensor: Calculated corner loss with shape (N). torch.FloatTensor: Calculated corner loss with shape (N).
""" """
assert pred_bbox3d.shape[0] == gt_bbox3d.shape[0] assert pred_bbox3d.shape[0] == gt_bbox3d.shape[0]
......
...@@ -57,6 +57,11 @@ class PointwiseSemanticHead(nn.Module): ...@@ -57,6 +57,11 @@ class PointwiseSemanticHead(nn.Module):
Returns: Returns:
dict: part features, segmentation and part predictions. dict: part features, segmentation and part predictions.
- seg_preds (torch.Tensor): segment predictions
- part_preds (torch.Tensor): part predictions
- part_feats (torch.Tensor): feature predictions
""" """
seg_preds = self.seg_cls_layer(x) # (N, 1) seg_preds = self.seg_cls_layer(x) # (N, 1)
part_preds = self.seg_reg_layer(x) # (N, 3) part_preds = self.seg_reg_layer(x) # (N, 3)
...@@ -83,7 +88,7 @@ class PointwiseSemanticHead(nn.Module): ...@@ -83,7 +88,7 @@ class PointwiseSemanticHead(nn.Module):
gt_labels_3d (torch.Tensor): shape [box_num], class label of gt gt_labels_3d (torch.Tensor): shape [box_num], class label of gt
Returns: Returns:
tuple : segmentation targets with shape [voxel_num] tuple[torch.Tensor]: segmentation targets with shape [voxel_num]
part prediction targets with shape [voxel_num, 3] part prediction targets with shape [voxel_num, 3]
""" """
gt_bboxes_3d = gt_bboxes_3d.to(voxel_centers.device) gt_bboxes_3d = gt_bboxes_3d.to(voxel_centers.device)
...@@ -130,8 +135,12 @@ class PointwiseSemanticHead(nn.Module): ...@@ -130,8 +135,12 @@ class PointwiseSemanticHead(nn.Module):
gt_labels_3d (list[torch.Tensor]): list of GT labels. gt_labels_3d (list[torch.Tensor]): list of GT labels.
Returns: Returns:
tuple : segmentation targets with shape [voxel_num] dict: prediction targets
part prediction targets with shape [voxel_num, 3]
- seg_targets (torch.Tensor): segmentation targets
with shape [voxel_num]
- part_targets (torch.Tensor): part prediction targets
with shape [voxel_num, 3]
""" """
batch_size = len(gt_labels_3d) batch_size = len(gt_labels_3d)
voxel_center_list = [] voxel_center_list = []
...@@ -151,10 +160,20 @@ class PointwiseSemanticHead(nn.Module): ...@@ -151,10 +160,20 @@ class PointwiseSemanticHead(nn.Module):
Args: Args:
semantic_results (dict): Results from semantic head. semantic_results (dict): Results from semantic head.
- seg_preds: segmentation predictions
- part_preds: part predictions
semantic_targets (dict): Targets of semantic results. semantic_targets (dict): Targets of semantic results.
- seg_targets: segmentation targets
- part_targets: part targets
Returns: Returns:
dict: loss of segmentation and part prediction. dict: loss of segmentation and part prediction.
- loss_seg (torch.Tensor): segmentation prediction loss
- loss_part (torch.Tensor): part prediction loss
""" """
seg_preds = semantic_results['seg_preds'] seg_preds = semantic_results['seg_preds']
part_preds = semantic_results['part_preds'] part_preds = semantic_results['part_preds']
......
...@@ -96,6 +96,9 @@ class PartAggregationROIHead(Base3DRoIHead): ...@@ -96,6 +96,9 @@ class PartAggregationROIHead(Base3DRoIHead):
Returns: Returns:
dict: losses from each head. dict: losses from each head.
- loss_semantic (torch.Tensor): loss of semantic head
- loss_bbox (torch.Tensor): loss of bboxes
""" """
losses = dict() losses = dict()
if self.with_semantic: if self.with_semantic:
......
...@@ -32,14 +32,14 @@ class Single3DRoIAwareExtractor(nn.Module): ...@@ -32,14 +32,14 @@ class Single3DRoIAwareExtractor(nn.Module):
"""Extract point-wise roi features. """Extract point-wise roi features.
Args: Args:
feats (FloatTensor): point-wise features with feats (torch.FloatTensor): point-wise features with
shape (batch, npoints, channels) for pooling shape (batch, npoints, channels) for pooling
coordinate (FloatTensor): coordinate of each point coordinate (torch.FloatTensor): coordinate of each point
batch_inds (longTensor): indicate the batch of each point batch_inds (torch.LongTensor): indicate the batch of each point
rois (FloatTensor): roi boxes with batch indices rois (torch.FloatTensor): roi boxes with batch indices
Returns: Returns:
FloatTensor: pooled features torch.FloatTensor: pooled features
""" """
pooled_roi_feats = [] pooled_roi_feats = []
for batch_idx in range(int(batch_inds.max()) + 1): for batch_idx in range(int(batch_inds.max()) + 1):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment