Commit d891d5c0 authored by wangtai, committed by zhangwenwei


Update mmdet3d/models/dense_heads/anchor3d_head.py, mmdet3d/models/dense_heads/free_anchor3d_head.py, mmdet3d/models/dense_heads/parta2_rpn_head.py, mmdet3d/models/dense_heads/train_mixins.py, mmdet3d/models/dense_heads/vote_head.py, mmdet3d/models/roi_heads/bbox_heads/parta2_bbox_head.py, mmdet3d/models/roi_heads/mask_heads/pointwise_semantic_head.py, mmdet3d/models/roi_heads/roi_extractors/single_roiaware_extractor.py, mmdet3d/models/roi_heads/base_3droi_head.py, mmdet3d/models/roi_heads/part_aggregation_roi_head.py files
parent ea779429
@@ -169,7 +169,8 @@ class Anchor3DHead(nn.Module, AnchorTrainMixin):
device (str): device of current module
Returns:
-tuple: anchors of each image, valid flags of each image
+list[list[torch.Tensor]]: anchors of each image, valid flags
+    of each image
"""
num_imgs = len(input_metas)
# since feature map sizes of all images are the same, we only compute
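The updated annotation in this hunk changes the documented return from a bare `tuple` to `list[list[torch.Tensor]]`, i.e. anchors grouped first by image and then by feature-map level. A minimal sketch of that nesting, assuming the per-level anchors are simply repeated for every image as the truncated comment above suggests (all names and shapes here are illustrative, not taken from the commit):

```python
import torch

# Hypothetical per-level anchors: 2 levels with different feature-map sizes,
# each tensor holding (num_anchors_on_level, 7) 3D boxes.
multi_level_anchors = [torch.rand(200, 7), torch.rand(50, 7)]

# Feature-map sizes are identical across images, so the same per-level
# anchors can be reused for every image in the batch.
num_imgs = 4
anchor_list = [multi_level_anchors for _ in range(num_imgs)]

# anchor_list[img_id][level_id] -> anchors of that image at that level.
for img_id, per_img in enumerate(anchor_list):
    for level_id, anchors in enumerate(per_img):
        print(img_id, level_id, anchors.shape)
```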
@@ -253,7 +254,8 @@ class Anchor3DHead(nn.Module, AnchorTrainMixin):
dimension is rotation dimension
Returns:
-tuple: (boxes1, boxes2) whose 7th dimensions are changed
+tuple[torch.Tensor]: boxes1 and boxes2 whose 7th dimensions
+    are changed
"""
rad_pred_encoding = torch.sin(boxes1[..., 6:7]) * torch.cos(
boxes2[..., 6:7])
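The two code lines visible in this hunk are the start of the sine-difference encoding that lets the bbox loss regress only the rotation residual, using sin(a - b) = sin(a)cos(b) - cos(a)sin(b). A hedged sketch of the whole helper; the target-side encoding and the concatenation step are assumptions based on the common pattern, not quoted from the diff:

```python
import torch


def add_sin_difference(boxes1, boxes2):
    """Encode the 7th (yaw) dimension via the sine-difference trick.

    boxes1 keeps sin(a)cos(b) and boxes2 keeps cos(a)sin(b), so a plain
    regression loss between the two encodings penalizes sin(a - b).
    """
    rad_pred_encoding = torch.sin(boxes1[..., 6:7]) * torch.cos(
        boxes2[..., 6:7])
    rad_tg_encoding = torch.cos(boxes1[..., 6:7]) * torch.sin(
        boxes2[..., 6:7])
    boxes1 = torch.cat(
        [boxes1[..., :6], rad_pred_encoding, boxes1[..., 7:]], dim=-1)
    boxes2 = torch.cat(
        [boxes2[..., :6], rad_tg_encoding, boxes2[..., 7:]], dim=-1)
    return boxes1, boxes2


pred = torch.rand(8, 7)
target = torch.rand(8, 7)
pred_enc, target_enc = add_sin_difference(pred, target)
assert pred_enc.shape == target_enc.shape == (8, 7)
```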
@@ -289,6 +291,10 @@ class Anchor3DHead(nn.Module, AnchorTrainMixin):
Returns:
dict: Contain class, bbox and direction losses of each level.
+- loss_cls (list[torch.Tensor]): class losses
+- loss_bbox (list[torch.Tensor]): bbox losses
+- loss_dir (list[torch.Tensor]): direction losses
"""
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
assert len(featmap_sizes) == self.anchor_generator.num_levels
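The new bullets document a dict holding one loss tensor per feature-map level for each branch. A small illustrative example of how a caller might reduce such a dict to a single scalar for backpropagation; the dict below is hand-built, not produced by the real head:

```python
import torch

# Hypothetical output mirroring the documented structure: one loss tensor
# per feature-map level for each of the three branches.
losses = dict(
    loss_cls=[torch.tensor(0.7), torch.tensor(0.3)],
    loss_bbox=[torch.tensor(1.2), torch.tensor(0.8)],
    loss_dir=[torch.tensor(0.1), torch.tensor(0.05)],
)

# Sum every per-level tensor into one scalar suitable for backward().
total_loss = sum(value for values in losses.values() for value in values)
print(total_loss)  # tensor(3.1500)
```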
@@ -404,6 +410,7 @@ class Anchor3DHead(nn.Module, AnchorTrainMixin):
Returns:
tuple: Contain predictions of single batch.
- bboxes (:obj:`BaseInstance3DBoxes`): Predicted 3d bboxes.
- scores (torch.Tensor): Class score of each bbox.
- labels (torch.Tensor): Label of each bbox.
@@ -63,6 +63,9 @@ class FreeAnchor3DHead(Anchor3DHead):
Returns:
dict: Loss items.
+- positive_bag_loss (torch.Tensor): Loss of positive samples.
+- negative_bag_loss (torch.Tensor): Loss of negative samples.
"""
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
assert len(featmap_sizes) == self.anchor_generator.num_levels
@@ -121,6 +121,7 @@ class PartA2RPNHead(Anchor3DHead):
Returns:
dict: Predictions of single batch. Contain the keys:
- boxes_3d (:obj:`BaseInstance3DBoxes`): Predicted 3d bboxes.
- scores_3d (torch.Tensor): Score of each bbox.
- labels_3d (torch.Tensor): Label of each bbox.
@@ -217,6 +218,7 @@ class PartA2RPNHead(Anchor3DHead):
Returns:
dict: Predictions of single batch. Contain the keys:
- boxes_3d (:obj:`BaseInstance3DBoxes`): Predicted 3d bboxes.
- scores_3d (torch.Tensor): Score of each bbox.
- labels_3d (torch.Tensor): Label of each bbox.
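Both PartA2RPNHead hunks document the same per-sample result dict with keys `boxes_3d`, `scores_3d` and `labels_3d`. A hedged example of building and score-filtering such a dict, with a plain tensor standing in for the `BaseInstance3DBoxes` object (e.g. `LiDARInstance3DBoxes`) the real head returns:

```python
import torch

# Stand-in for a batch of predictions; mmdet3d would wrap boxes_3d in a
# BaseInstance3DBoxes subclass rather than a raw tensor.
result = dict(
    boxes_3d=torch.rand(100, 7),
    scores_3d=torch.rand(100),
    labels_3d=torch.randint(0, 3, (100,)),
)

# Keep only confident predictions, preserving the three aligned fields.
keep = result['scores_3d'] > 0.5
filtered = {key: value[keep] for key, value in result.items()}
print(filtered['boxes_3d'].shape, filtered['labels_3d'].shape)
```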
@@ -30,7 +30,11 @@ class AnchorTrainMixin(object):
sampling (bool): Whether to sample anchors.
Returns:
-tuple: Anchor targets.
+tuple (list, list, list, list, list, list, int, int):
+    Anchor targets, including labels, label weights,
+    bbox targets, bbox weights, direction targets,
+    direction weights, number of positive anchors and
+    number of negative anchors.
"""
num_imgs = len(input_metas)
assert len(anchor_list) == num_imgs
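The expanded docstring spells out an eight-element tuple returned by the anchor target computation. A sketch of how such a value might be unpacked on the caller side, using a hand-built stand-in tuple (the counts and shapes are invented for illustration):

```python
import torch

# Hypothetical stand-in for the documented 8-tuple: six per-level lists
# followed by the positive / negative anchor counts.
anchor_targets = (
    [torch.zeros(100, dtype=torch.long)],  # labels
    [torch.ones(100)],                     # label weights
    [torch.zeros(100, 7)],                 # bbox targets
    [torch.zeros(100, 7)],                 # bbox weights
    [torch.zeros(100, dtype=torch.long)],  # direction targets
    [torch.zeros(100)],                    # direction weights
    12,                                    # number of positive anchors
    88,                                    # number of negative anchors
)

(labels, label_weights, bbox_targets, bbox_weights, dir_targets,
 dir_weights, num_pos, num_neg) = anchor_targets

# The two counts are typically combined into the normalizer used when
# averaging per-anchor losses.
num_total_samples = num_pos + num_neg
print(num_total_samples)  # 100
```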
@@ -105,7 +109,7 @@ class AnchorTrainMixin(object):
sampling (bool): Whether to sample anchors.
Returns:
-tuple: Anchor targets.
+tuple[torch.Tensor]: Anchor targets.
"""
if isinstance(self.bbox_assigner, list):
feat_size = anchors.size(0) * anchors.size(1) * anchors.size(2)
@@ -194,7 +198,7 @@ class AnchorTrainMixin(object):
sampling (bool): Whether to sample anchors.
Returns:
-tuple: Anchor targets.
+tuple[torch.Tensor]: Anchor targets.
"""
anchors = anchors.reshape(-1, anchors.size(-1))
num_valid_anchors = anchors.shape[0]
@@ -305,7 +305,7 @@ class VoteHead(nn.Module):
bbox_preds (torch.Tensor): Bbox predictions of vote head.
Returns:
-tuple: Targets of vote head.
+tuple[torch.Tensor]: Targets of vote head.
"""
# find empty example
valid_gt_masks = list()
@@ -391,7 +391,7 @@ class VoteHead(nn.Module):
vote aggregation layer.
Returns:
-tuple: Targets of vote head.
+tuple[torch.Tensor]: Targets of vote head.
"""
assert self.bbox_coder.with_rot or pts_semantic_mask is not None
@@ -71,8 +71,9 @@ class Base3DRoIHead(nn.Module, metaclass=ABCMeta):
gt_bboxes (list[:obj:`BaseInstance3DBoxes`]):
GT bboxes of each sample. The bboxes are encapsulated
by 3D box structures.
-gt_labels (list[LongTensor]): GT labels of each sample.
-gt_bboxes_ignore (list[Tensor], optional): Specify which bounding.
+gt_labels (list[torch.LongTensor]): GT labels of each sample.
+gt_bboxes_ignore (list[torch.Tensor], optional):
+    Specify which bounding.
Returns:
dict: losses from each head.
@@ -297,6 +297,10 @@ class PartA2BboxHead(nn.Module):
Returns:
dict: Computed losses.
+- loss_cls (torch.Tensor): loss of classes.
+- loss_bbox (torch.Tensor): loss of bboxes.
+- loss_corner (torch.Tensor): loss of corners.
"""
losses = dict()
rcnn_batch_size = cls_score.shape[0]
@@ -359,7 +363,7 @@ class PartA2BboxHead(nn.Module):
concat (bool): Whether to concatenate targets between batches.
Returns:
-tuple: Targets of boxes and class prediction.
+tuple[torch.Tensor]: Targets of boxes and class prediction.
"""
pos_bboxes_list = [res.pos_bboxes for res in sampling_results]
pos_gt_bboxes_list = [res.pos_gt_bboxes for res in sampling_results]
@@ -402,7 +406,7 @@ class PartA2BboxHead(nn.Module):
cfg (dict): Training configs.
Returns:
-tuple: Target for positive boxes.
+tuple[torch.Tensor]: Target for positive boxes.
(label, bbox_targets, pos_gt_bboxes, reg_mask, label_weights,
bbox_weights)
"""
@@ -459,11 +463,11 @@ class PartA2BboxHead(nn.Module):
"""Calculate corner loss of given boxes.
Args:
-pred_bbox3d (FloatTensor): predicted boxes with shape (N, 7).
-gt_bbox3d (FloatTensor): gt boxes with shape (N, 7).
+pred_bbox3d (torch.FloatTensor): predicted boxes with shape (N, 7).
+gt_bbox3d (torch.FloatTensor): gt boxes with shape (N, 7).
Returns:
-FloatTensor: Calculated corner loss with shape (N).
+torch.FloatTensor: Calculated corner loss with shape (N).
"""
assert pred_bbox3d.shape[0] == gt_bbox3d.shape[0]
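The corner-loss docstring above refers to comparing boxes through their corner points rather than their parameters. A simplified, hedged sketch of that idea for axis-aligned boxes; it ignores the yaw rotation and the flipped-ground-truth handling of the real PartA2 implementation, so it illustrates the concept rather than reproducing the library code:

```python
import itertools

import torch


def box_corners(boxes):
    """Eight corners of axis-aligned boxes given as (x, y, z, dx, dy, dz, yaw).

    The yaw column is ignored to keep the sketch short; the real loss rotates
    the corners and also compares against the gt box flipped by pi.
    """
    centers = boxes[:, :3].unsqueeze(1)             # (N, 1, 3)
    half_dims = boxes[:, 3:6].unsqueeze(1) / 2      # (N, 1, 3)
    signs = torch.tensor(
        list(itertools.product((-1.0, 1.0), repeat=3)))  # (8, 3)
    return centers + signs * half_dims              # (N, 8, 3)


def corner_loss(pred_bbox3d, gt_bbox3d):
    """Mean corner distance per box, returned with shape (N,)."""
    diff = box_corners(pred_bbox3d) - box_corners(gt_bbox3d)
    return diff.norm(dim=-1).mean(dim=-1)


pred = torch.rand(4, 7)
gt = torch.rand(4, 7)
print(corner_loss(pred, gt).shape)  # torch.Size([4])
```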
@@ -57,6 +57,11 @@ class PointwiseSemanticHead(nn.Module):
Returns:
dict: part features, segmentation and part predictions.
+- seg_preds (torch.Tensor): segment predictions
+- part_preds (torch.Tensor): part predictions
+- part_feats (torch.Tensor): feature predictions
"""
seg_preds = self.seg_cls_layer(x) # (N, 1)
part_preds = self.seg_reg_layer(x) # (N, 3)
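The two visible lines show `seg_preds` and `part_preds` coming from small per-point layers. A toy, hedged version of a head that returns the documented dict keys; the layer sizes and the reuse of the input features as `part_feats` are assumptions for illustration only:

```python
import torch
from torch import nn


class TinyPointwiseHead(nn.Module):
    """Toy stand-in mirroring the documented output keys."""

    def __init__(self, in_channels=16):
        super().__init__()
        self.seg_cls_layer = nn.Linear(in_channels, 1)  # foreground score
        self.seg_reg_layer = nn.Linear(in_channels, 3)  # intra-object part location

    def forward(self, x):
        seg_preds = self.seg_cls_layer(x)    # (N, 1)
        part_preds = self.seg_reg_layer(x)   # (N, 3)
        # Reuse the input features as "part_feats" for the RoI stage; the
        # real head builds richer features, this is only illustrative.
        return dict(seg_preds=seg_preds, part_preds=part_preds, part_feats=x)


head = TinyPointwiseHead()
out = head(torch.rand(1000, 16))
print(out['seg_preds'].shape, out['part_preds'].shape, out['part_feats'].shape)
```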
@@ -83,7 +88,7 @@ class PointwiseSemanticHead(nn.Module):
gt_labels_3d (torch.Tensor): shape [box_num], class label of gt
Returns:
-tuple : segmentation targets with shape [voxel_num]
+tuple[torch.Tensor]: segmentation targets with shape [voxel_num]
part prediction targets with shape [voxel_num, 3]
"""
gt_bboxes_3d = gt_bboxes_3d.to(voxel_centers.device)
@@ -130,8 +135,12 @@ class PointwiseSemanticHead(nn.Module):
gt_labels_3d (list[torch.Tensor]): list of GT labels.
Returns:
-tuple : segmentation targets with shape [voxel_num]
-    part prediction targets with shape [voxel_num, 3]
+dict: prediction targets
+    - seg_targets (torch.Tensor): segmentation targets
+      with shape [voxel_num]
+    - part_targets (torch.Tensor): part prediction targets
+      with shape [voxel_num, 3]
"""
batch_size = len(gt_labels_3d)
voxel_center_list = []
@@ -151,10 +160,20 @@ class PointwiseSemanticHead(nn.Module):
Args:
semantic_results (dict): Results from semantic head.
+- seg_preds: segmentation predictions
+- part_preds: part predictions
semantic_targets (dict): Targets of semantic results.
+- seg_preds: segmentation targets
+- part_preds: part targets
Returns:
dict: loss of segmentation and part prediction.
+- loss_seg (torch.Tensor): segmentation prediction loss
+- loss_part (torch.Tensor): part prediction loss
"""
seg_preds = semantic_results['seg_preds']
part_preds = semantic_results['part_preds']
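Given the two dicts documented above, the loss splits into a binary segmentation term and a part-regression term that only counts foreground points. A minimal hedged sketch using plain BCE and L1 losses in place of the losses configured in mmdet3d; the target keys `seg_targets`/`part_targets` follow the earlier PointwiseSemanticHead hunk and are an assumption here:

```python
import torch
import torch.nn.functional as F


def toy_semantic_loss(semantic_results, semantic_targets):
    seg_preds = semantic_results['seg_preds'].squeeze(-1)   # (N,)
    part_preds = semantic_results['part_preds']             # (N, 3)
    seg_targets = semantic_targets['seg_targets'].float()   # (N,)
    part_targets = semantic_targets['part_targets']         # (N, 3)

    loss_seg = F.binary_cross_entropy_with_logits(seg_preds, seg_targets)

    # Part locations are only meaningful inside ground-truth boxes, so the
    # regression term is averaged over foreground points only.
    fg = seg_targets > 0
    loss_part = (F.l1_loss(part_preds[fg], part_targets[fg])
                 if fg.any() else part_preds.sum() * 0)
    return dict(loss_seg=loss_seg, loss_part=loss_part)


results = dict(seg_preds=torch.randn(100, 1), part_preds=torch.rand(100, 3))
targets = dict(seg_targets=(torch.rand(100) > 0.7).long(),
               part_targets=torch.rand(100, 3))
print(toy_semantic_loss(results, targets))
```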
@@ -96,6 +96,9 @@ class PartAggregationROIHead(Base3DRoIHead):
Returns:
dict: losses from each head.
+- loss_semantic (torch.Tensor): loss of semantic head
+- loss_bbox (torch.Tensor): loss of bboxes
"""
losses = dict()
if self.with_semantic:
@@ -32,14 +32,14 @@ class Single3DRoIAwareExtractor(nn.Module):
"""Extract point-wise roi features.
Args:
-feats (FloatTensor): point-wise features with
+feats (torch.FloatTensor): point-wise features with
shape (batch, npoints, channels) for pooling
-coordinate (FloatTensor): coordinate of each point
-batch_inds (longTensor): indicate the batch of each point
-rois (FloatTensor): roi boxes with batch indices
+coordinate (torch.FloatTensor): coordinate of each point
+batch_inds (torch.LongTensor): indicate the batch of each point
+rois (torch.FloatTensor): roi boxes with batch indices
Returns:
-FloatTensor: pooled features
+torch.FloatTensor: pooled features
"""
pooled_roi_feats = []
for batch_idx in range(int(batch_inds.max()) + 1):
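The visible loop walks over batch indices so that each sample's points and RoIs are pooled separately. A hedged, simplified sketch of that per-batch pattern with a placeholder mean-pooling step instead of the real `RoIAwarePool3d` op (`coordinate` is accepted only for signature parity and is unused in this toy):

```python
import torch


def toy_roi_feature_extract(feats, coordinate, batch_inds, rois):
    """Pool point-wise features per batch element.

    Placeholder pooling: the mean over each sample's points is broadcast to
    that sample's RoIs, standing in for the real RoI-aware voxel pooling.
    """
    pooled_roi_feats = []
    for batch_idx in range(int(batch_inds.max()) + 1):
        mask = batch_inds == batch_idx
        cur_feats = feats[mask]                                 # points of this sample
        cur_rois = rois[rois[:, 0].int() == batch_idx][:, 1:]   # drop batch column
        pooled = cur_feats.mean(0, keepdim=True).repeat(cur_rois.shape[0], 1)
        pooled_roi_feats.append(pooled)
    return torch.cat(pooled_roi_feats, dim=0)


feats = torch.rand(2000, 16)
coordinate = torch.rand(2000, 3)
batch_inds = torch.randint(0, 2, (2000,))
rois = torch.cat([torch.randint(0, 2, (10, 1)).float(), torch.rand(10, 7)], dim=1)
print(toy_roi_feature_extract(feats, coordinate, batch_inds, rois).shape)
```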