"python/vscode:/vscode.git/clone" did not exist on "89cd923581fec16d70ed536eceac7212dc6e0898"
Commit 41978daf authored by zhangwenwei's avatar zhangwenwei
Browse files

Complete the docstrings of detector

parent 3d29ab20
...@@ -41,6 +41,7 @@ class SingleStage3DDetector(Base3DDetector): ...@@ -41,6 +41,7 @@ class SingleStage3DDetector(Base3DDetector):
self.init_weights(pretrained=pretrained) self.init_weights(pretrained=pretrained)
def init_weights(self, pretrained=None): def init_weights(self, pretrained=None):
"""Initialize weights of detector."""
super(SingleStage3DDetector, self).init_weights(pretrained) super(SingleStage3DDetector, self).init_weights(pretrained)
self.backbone.init_weights(pretrained=pretrained) self.backbone.init_weights(pretrained=pretrained)
if self.with_neck: if self.with_neck:
...@@ -63,6 +64,7 @@ class SingleStage3DDetector(Base3DDetector): ...@@ -63,6 +64,7 @@ class SingleStage3DDetector(Base3DDetector):
return x return x
def extract_feats(self, points, img_metas): def extract_feats(self, points, img_metas):
"""Extract features of multiple samples."""
return [ return [
self.extract_feat(pts, img_meta) self.extract_feat(pts, img_meta)
for pts, img_meta in zip(points, img_metas) for pts, img_meta in zip(points, img_metas)
......
...@@ -84,6 +84,7 @@ class VoteNet(SingleStage3DDetector): ...@@ -84,6 +84,7 @@ class VoteNet(SingleStage3DDetector):
return bbox_results[0] return bbox_results[0]
def aug_test(self, points, img_metas, imgs=None, rescale=False): def aug_test(self, points, img_metas, imgs=None, rescale=False):
"""Test with augmentation."""
points_cat = [torch.stack(pts) for pts in points] points_cat = [torch.stack(pts) for pts in points]
feats = self.extract_feats(points_cat, img_metas) feats = self.extract_feats(points_cat, img_metas)
......
...@@ -10,6 +10,7 @@ from .single_stage import SingleStage3DDetector ...@@ -10,6 +10,7 @@ from .single_stage import SingleStage3DDetector
@DETECTORS.register_module() @DETECTORS.register_module()
class VoxelNet(SingleStage3DDetector): class VoxelNet(SingleStage3DDetector):
r"""`VoxelNet <https://arxiv.org/abs/1711.06396>`_ for 3D detection."""
def __init__(self, def __init__(self,
voxel_layer, voxel_layer,
...@@ -68,6 +69,21 @@ class VoxelNet(SingleStage3DDetector): ...@@ -68,6 +69,21 @@ class VoxelNet(SingleStage3DDetector):
gt_bboxes_3d, gt_bboxes_3d,
gt_labels_3d, gt_labels_3d,
gt_bboxes_ignore=None): gt_bboxes_ignore=None):
"""Training forward function.
Args:
points (list[torch.Tensor]): Point cloud of each sample.
img_metas (list[dict]): Meta information of each sample
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
boxes for each sample.
gt_labels_3d (list[torch.Tensor]): Ground truth labels for
boxes of each sampole
gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
boxes to be ignored. Defaults to None.
Returns:
dict: Losses of each branch.
"""
x = self.extract_feat(points, img_metas) x = self.extract_feat(points, img_metas)
outs = self.bbox_head(x) outs = self.bbox_head(x)
loss_inputs = outs + (gt_bboxes_3d, gt_labels_3d, img_metas) loss_inputs = outs + (gt_bboxes_3d, gt_labels_3d, img_metas)
...@@ -76,6 +92,7 @@ class VoxelNet(SingleStage3DDetector): ...@@ -76,6 +92,7 @@ class VoxelNet(SingleStage3DDetector):
return losses return losses
def simple_test(self, points, img_metas, imgs=None, rescale=False): def simple_test(self, points, img_metas, imgs=None, rescale=False):
"""Test function without augmentaiton."""
x = self.extract_feat(points, img_metas) x = self.extract_feat(points, img_metas)
outs = self.bbox_head(x) outs = self.bbox_head(x)
bbox_list = self.bbox_head.get_bboxes( bbox_list = self.bbox_head.get_bboxes(
...@@ -87,6 +104,7 @@ class VoxelNet(SingleStage3DDetector): ...@@ -87,6 +104,7 @@ class VoxelNet(SingleStage3DDetector):
return bbox_results[0] return bbox_results[0]
def aug_test(self, points, img_metas, imgs=None, rescale=False): def aug_test(self, points, img_metas, imgs=None, rescale=False):
"""Test function with augmentaiton."""
feats = self.extract_feats(points, img_metas) feats = self.extract_feats(points, img_metas)
# only support aug_test for one sample # only support aug_test for one sample
......
...@@ -220,6 +220,7 @@ class PointFusion(nn.Module): ...@@ -220,6 +220,7 @@ class PointFusion(nn.Module):
# default init_weights for conv(msra) and norm in ConvModule # default init_weights for conv(msra) and norm in ConvModule
def init_weights(self): def init_weights(self):
"""Initialize the weights of modules."""
for m in self.modules(): for m in self.modules():
if isinstance(m, (nn.Conv2d, nn.Linear)): if isinstance(m, (nn.Conv2d, nn.Linear)):
xavier_init(m, distribution='uniform') xavier_init(m, distribution='uniform')
...@@ -252,6 +253,17 @@ class PointFusion(nn.Module): ...@@ -252,6 +253,17 @@ class PointFusion(nn.Module):
return fuse_out return fuse_out
def obtain_mlvl_feats(self, img_feats, pts, img_metas): def obtain_mlvl_feats(self, img_feats, pts, img_metas):
"""Obtain multi-level features for each point.
Args:
img_feats (list(torch.Tensor)): Multi-scale image features produced
by image backbone in shape (N, C, H, W).
pts (list[torch.Tensor]): Points of each sample.
img_metas (list[dict]): Meta information for each sample.
Returns:
torch.Tensor: Corresponding image features of each point.
"""
if self.lateral_convs is not None: if self.lateral_convs is not None:
img_ins = [ img_ins = [
lateral_conv(img_feats[i]) lateral_conv(img_feats[i])
...@@ -277,6 +289,17 @@ class PointFusion(nn.Module): ...@@ -277,6 +289,17 @@ class PointFusion(nn.Module):
return img_pts return img_pts
def sample_single(self, img_feats, pts, img_meta): def sample_single(self, img_feats, pts, img_meta):
"""Sample features from single level image feature map.
Args:
img_feats (torch.Tensor): Image feature map in shape
(N, C, H, W).
pts (torch.Tensor): Points of a single sample.
img_meta (dict): Meta information of the single sample.
Returns:
torch.Tensor: Single level image features of each point.
"""
pcd_scale_factor = ( pcd_scale_factor = (
img_meta['pcd_scale_factor'] img_meta['pcd_scale_factor']
if 'pcd_scale_factor' in img_meta.keys() else 1) if 'pcd_scale_factor' in img_meta.keys() else 1)
......
...@@ -14,16 +14,16 @@ def chamfer_distance(src, ...@@ -14,16 +14,16 @@ def chamfer_distance(src,
"""Calculate Chamfer Distance of two sets. """Calculate Chamfer Distance of two sets.
Args: Args:
src (tensor): Source set with shape [B, N, C] to src (Tensor): Source set with shape [B, N, C] to
calculate Chamfer Distance. calculate Chamfer Distance.
dst (tensor): Destination set with shape [B, M, C] to dst (Tensor): Destination set with shape [B, M, C] to
calculate Chamfer Distance. calculate Chamfer Distance.
src_weight (tensor or float): Weight of source loss. src_weight (Tensor or float): Weight of source loss.
dst_weight (tensor or float): Weight of destination loss. dst_weight (Tensor or float): Weight of destination loss.
criterion_mode (str): Criterion mode to calculate distance. criterion_mode (str): Criterion mode to calculate distance.
The valid modes are smooth_l1, l1 or l2. The valid modes are smooth_l1, l1 or l2.
reduction (str): Method to reduce losses. reduction (str): Method to reduce losses.
The valid reduction method are none, sum or mean. The valid reduction method are 'none', 'sum' or 'mean'.
Returns: Returns:
tuple: Source and Destination loss with indices. tuple: Source and Destination loss with indices.
...@@ -103,6 +103,29 @@ class ChamferDistance(nn.Module): ...@@ -103,6 +103,29 @@ class ChamferDistance(nn.Module):
reduction_override=None, reduction_override=None,
return_indices=False, return_indices=False,
**kwargs): **kwargs):
"""Forward function of loss calculation.
Args:
source (Tensor): Source set with shape [B, N, C] to
calculate Chamfer Distance.
target (Tensor): Destination set with shape [B, M, C] to
calculate Chamfer Distance.
src_weight (Tensor | float, optional): Weight of source loss.
Defaults to 1.0.
dst_weight (Tensor | float, optional): Weight of destination loss.
Defaults to 1.0.
reduction_override (str, optional): Method to reduce losses.
The valid reduction method are 'none', 'sum' or 'mean'.
Defaults to None.
return_indices (bool, optional): Whether to return indices.
Defaults to False.
Returns:
tuple[torch.Tensor]: If ``return_indices=True``, return losses of
source and target with their corresponding indices in the order
of (loss_source, loss_target, indices1, indices2). If
``return_indices=False``, return (loss_source, loss_target).
"""
assert reduction_override in (None, 'none', 'mean', 'sum') assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = ( reduction = (
reduction_override if reduction_override else self.reduction) reduction_override if reduction_override else self.reduction)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment