Commit 219deef1 authored by zhangwenwei

Merge branch 'complete-detector-doc' into 'master'

Complete the docstrings of detector

See merge request open-mmlab/mmdet.3d!130
parents 3d29ab20 41978daf
@@ -36,7 +36,7 @@ linting:
     - echo "Start testing..."
     - coverage run --branch --source mmdet3d -m pytest tests/
     - coverage report -m
-    - interrogate -v --ignore-init-method --ignore-module --ignore-nested-functions --exclude mmdet3d/ops --ignore-regex "__repr__" --fail-under 44 mmdet3d
+    - interrogate -v --ignore-init-method --ignore-module --ignore-nested-functions --exclude mmdet3d/ops --ignore-regex "__repr__" --fail-under 80 mmdet3d
 test:pytorch1.3-cuda10:
   image: $PYTORCH_IMAGE
...
@@ -102,7 +102,7 @@ class PartialBinBasedBBoxCoder(BaseBBoxCoder):
             base_xyz (torch.Tensor): coordinates of points.
         Returns:
-            dict: split results.
+            dict[str, torch.Tensor]: split results.
         """
         results = {}
         start, end = 0, 0
...
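The `dict[str, torch.Tensor]` annotation reflects how `split_pred` slices one prediction tensor into named chunks, as hinted by the `results = {}` / `start, end = 0, 0` lines above. A minimal sketch of that pattern, with illustrative key names and channel widths (not the coder's actual layout):

```python
import torch

def split_pred_sketch(preds, chunk_sizes):
    """Slice a (B, C, N) prediction tensor into named chunks.

    chunk_sizes maps each illustrative key to its channel width.
    """
    results = {}
    start, end = 0, 0
    for name, size in chunk_sizes.items():
        start, end = end, end + size
        results[name] = preds[:, start:end, :]
    return results

out = split_pred_sketch(torch.rand(2, 7, 128), {'center': 3, 'dir_class': 4})
print({k: tuple(v.shape) for k, v in out.items()})
# {'center': (2, 3, 128), 'dir_class': (2, 4, 128)}
```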
@@ -221,7 +221,7 @@ def indoor_eval(gt_annos,
             summary. See `mmdet.utils.print_log()` for details. Default: None.
     Return:
-        dict: Dict of results.
+        dict[str, float]: Dict of results.
     """
     assert len(dt_annos) == len(gt_annos)
     pred = {}  # map {class_id: pred}
...
@@ -468,7 +468,7 @@ def eval_class(gt_annos,
         num_parts (int): A parameter for fast calculate algorithm
     Returns:
-        dict: recall, precision and aos
+        dict[str, np.ndarray]: recall, precision and aos
     """
     assert len(gt_annos) == len(dt_annos)
     num_examples = len(gt_annos)
...
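The `dict[str, np.ndarray]` return type exists because recall and precision are arrays over score thresholds (aos likewise, for orientation). A self-contained sketch, under the simplifying assumption of pre-accumulated cumulative TP/FP counts, of how such curves are formed:

```python
import numpy as np

def pr_curves(tp, fp, num_gt):
    """tp/fp: cumulative true/false positive counts per score threshold."""
    recall = tp / max(num_gt, 1)
    precision = tp / np.maximum(tp + fp, 1)
    return {'recall': recall, 'precision': precision}

curves = pr_curves(tp=np.array([1, 2, 3]), fp=np.array([0, 1, 1]), num_gt=4)
print(curves['recall'])     # [0.25 0.5  0.75]
print(curves['precision'])  # approx. [1.  0.667  0.75]
```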
@@ -99,7 +99,7 @@ def lyft_eval(lyft, data_root, res_path, eval_set, output_dir, logger=None):
             related information during evaluation. Default: None.
     Returns:
-        dict: The metric dictionary recording the evaluation results.
+        dict[str, float]: The evaluation results.
     """
     # evaluate by lyft metrics
     gts = load_lyft_gts(lyft, data_root, eval_set, logger)
...
@@ -18,6 +18,7 @@ def merge_aug_bboxes_3d(aug_results, img_metas, test_cfg):
     Returns:
         dict: bbox results in cpu mode, containing the merged results
+            - boxes_3d (:obj:`BaseInstance3DBoxes`): merged detection bbox
             - scores_3d (torch.Tensor): merged detection scores
             - labels_3d (torch.Tensor): merged predicted box labels
...
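For intuition, a heavily simplified sketch of the merge this docstring describes: per-augmentation results are concatenated and reduced to the highest-scoring boxes. The real function maps boxes back to the original frame and applies NMS; plain tensors and a top-k stand in for `BaseInstance3DBoxes` and NMS here:

```python
import torch

def merge_aug_sketch(aug_results, max_num=50):
    """aug_results: list of dicts with boxes_3d/scores_3d/labels_3d tensors."""
    boxes = torch.cat([r['boxes_3d'] for r in aug_results])
    scores = torch.cat([r['scores_3d'] for r in aug_results])
    labels = torch.cat([r['labels_3d'] for r in aug_results])
    keep = scores.topk(min(max_num, scores.numel())).indices
    # returned in CPU mode, matching the documented contract
    return dict(boxes_3d=boxes[keep].cpu(),
                scores_3d=scores[keep].cpu(),
                labels_3d=labels[keep].cpu())

res = merge_aug_sketch([
    dict(boxes_3d=torch.rand(3, 7), scores_3d=torch.rand(3),
         labels_3d=torch.zeros(3, dtype=torch.long)),
    dict(boxes_3d=torch.rand(2, 7), scores_3d=torch.rand(2),
         labels_3d=torch.ones(2, dtype=torch.long)),
])
print(res['boxes_3d'].shape)  # torch.Size([5, 7])
```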
@@ -313,7 +313,7 @@ class KittiDataset(Custom3DDataset):
                 Default: None.
         Returns:
-            dict[str: float]: results of each evaluation metric
+            dict[str, float]: results of each evaluation metric
         """
         result_files, tmp_dir = self.format_results(results, pklfile_prefix)
         from mmdet3d.core.evaluation import kitti_eval
...
@@ -378,7 +378,7 @@ class LyftDataset(Custom3DDataset):
                 Default: None.
         Returns:
-            dict[str: float]
+            dict[str, float]: Evaluation results.
         """
         result_files, tmp_dir = self.format_results(results, jsonfile_prefix,
                                                     csv_savepath)
...
@@ -61,9 +61,19 @@ class NoStemRegNet(RegNet):
         super(NoStemRegNet, self).__init__(arch, **kwargs)
     def _make_stem_layer(self, in_channels, base_channels):
+        """Override the original function so that no stem layer is built,
+        since the 3D detector's voxel encoder works like a stem layer."""
         return
     def forward(self, x):
+        """Forward function of backbone.
+        Args:
+            x (torch.Tensor): Features in shape (N, C, H, W).
+        Returns:
+            tuple[torch.Tensor]: Multi-scale features.
+        """
         outs = []
         for i, layer_name in enumerate(self.res_layers):
             res_layer = getattr(self, layer_name)
...
@@ -83,6 +83,7 @@ class PointNet2SASSG(nn.Module):
         fp_target_channel = skip_channel_list.pop()
     def init_weights(self, pretrained=None):
+        """Initialize the weights of PointNet backbone."""
         # Do not initialize the conv layers
         # to follow the original implementation
         if isinstance(pretrained, str):
@@ -118,7 +119,8 @@ class PointNet2SASSG(nn.Module):
                 with shape (B, N, 3 + input_feature_dim).
         Returns:
-            dict: outputs after SA and FP modules.
+            dict[str, list[Tensor]]: outputs after SA and FP modules.
                 - fp_xyz (list[Tensor]): contains the coordinates of
                   each fp features.
                 - fp_features (list[Tensor]): contains the features
...
@@ -62,6 +62,7 @@ class SECOND(nn.Module):
         self.blocks = nn.ModuleList(blocks)
     def init_weights(self, pretrained=None):
+        """Initialize weights of the 2D backbone."""
         # Do not initialize the conv layers
         # to follow the original implementation
         if isinstance(pretrained, str):
@@ -70,6 +71,14 @@ class SECOND(nn.Module):
         load_checkpoint(self, pretrained, strict=False, logger=logger)
     def forward(self, x):
+        """Forward function.
+        Args:
+            x (torch.Tensor): Input with shape (N, C, H, W).
+        Returns:
+            tuple[torch.Tensor]: Multi-scale features.
+        """
         outs = []
         for i in range(len(self.blocks)):
             x = self.blocks[i](x)
...
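The contract documented here: one (N, C, H, W) input, a tuple of progressively coarser feature maps out, one entry per block. A self-contained sketch with illustrative strided-conv blocks standing in for SECOND's real blocks:

```python
import torch
import torch.nn as nn

blocks = nn.ModuleList([
    nn.Conv2d(64, 128, 3, stride=2, padding=1),   # /2 resolution
    nn.Conv2d(128, 256, 3, stride=2, padding=1),  # /4 resolution
])

def forward(x):
    # collect every block's output, as in the loop above
    outs = []
    for block in blocks:
        x = block(x)
        outs.append(x)
    return tuple(outs)

feats = forward(torch.rand(1, 64, 200, 176))
print([tuple(f.shape) for f in feats])
# [(1, 128, 100, 88), (1, 256, 50, 44)]
```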
@@ -4,40 +4,50 @@ from .registry import FUSION_LAYERS, MIDDLE_ENCODERS, VOXEL_ENCODERS
 def build_backbone(cfg):
+    """Build backbone."""
     return build(cfg, BACKBONES)
 def build_neck(cfg):
+    """Build neck."""
     return build(cfg, NECKS)
 def build_roi_extractor(cfg):
+    """Build RoI feature extractor."""
     return build(cfg, ROI_EXTRACTORS)
 def build_shared_head(cfg):
+    """Build shared head of detector."""
     return build(cfg, SHARED_HEADS)
 def build_head(cfg):
+    """Build head."""
     return build(cfg, HEADS)
 def build_loss(cfg):
+    """Build loss function."""
     return build(cfg, LOSSES)
 def build_detector(cfg, train_cfg=None, test_cfg=None):
+    """Build detector."""
     return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
 def build_voxel_encoder(cfg):
+    """Build voxel encoder."""
     return build(cfg, VOXEL_ENCODERS)
 def build_middle_encoder(cfg):
+    """Build middle level encoder."""
     return build(cfg, MIDDLE_ENCODERS)
 def build_fusion_layer(cfg):
+    """Build fusion layer."""
     return build(cfg, FUSION_LAYERS)
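Each helper wraps the same registry-driven `build(cfg, REGISTRY)` call, so usage is uniform: pass a config dict whose `type` key names a registered class and whose remaining keys become constructor arguments. An illustrative sketch (the SECOND arguments shown are an assumption for demonstration, not a verified mmdet3d config):

```python
from mmdet3d.models import builder

# 'type' selects the registered class; the other keys are illustrative
# constructor arguments, not a tested configuration.
backbone = builder.build_backbone(
    dict(type='SECOND',
         in_channels=64,
         layer_nums=[3, 5, 5],
         out_channels=[64, 128, 256]))
```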
@@ -102,6 +102,7 @@ class Anchor3DHead(nn.Module, AnchorTrainMixin):
         self._init_assigner_sampler()
     def _init_assigner_sampler(self):
+        """Initialize the target assigner and sampler of the head."""
         if self.train_cfg is None:
             return
@@ -117,6 +118,7 @@ class Anchor3DHead(nn.Module, AnchorTrainMixin):
         ]
     def _init_layers(self):
+        """Initialize neural network layers of the head."""
         self.cls_out_channels = self.num_anchors * self.num_classes
         self.conv_cls = nn.Conv2d(self.feat_channels, self.cls_out_channels, 1)
         self.conv_reg = nn.Conv2d(self.feat_channels,
@@ -126,6 +128,7 @@ class Anchor3DHead(nn.Module, AnchorTrainMixin):
                                   self.num_anchors * 2, 1)
     def init_weights(self):
+        """Initialize the weights of head."""
         bias_cls = bias_init_with_prob(0.01)
         normal_init(self.conv_cls, std=0.01, bias=bias_cls)
         normal_init(self.conv_reg, std=0.01)
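`bias_init_with_prob(0.01)` computes the classification bias that makes the initial sigmoid output roughly equal the prior probability (the RetinaNet-style focal-loss initialization), i.e. `b = -log((1 - p) / p)`. A quick check of that formula:

```python
import math

def bias_init_with_prob_sketch(prior_prob):
    """Bias b such that sigmoid(b) == prior_prob."""
    return -math.log((1 - prior_prob) / prior_prob)

b = bias_init_with_prob_sketch(0.01)
print(round(b, 3))                        # -4.595
print(round(1 / (1 + math.exp(-b)), 3))   # 0.01, i.e. sigmoid(b) == prior
```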
@@ -290,11 +293,13 @@ class Anchor3DHead(nn.Module, AnchorTrainMixin):
                 which bounding.
         Returns:
-            dict: Contain class, bbox and direction losses of each level.
-            - loss_cls (list[torch.Tensor]): class losses
-            - loss_bbox (list[torch.Tensor]): bbox losses
-            - loss_dir (list[torch.Tensor]): direction losses
+            dict[str, list[torch.Tensor]]: Classification, bbox, and direction
+                losses of each level.
+                - loss_cls (list[torch.Tensor]): Classification losses.
+                - loss_bbox (list[torch.Tensor]): Box regression losses.
+                - loss_dir (list[torch.Tensor]): Direction classification
+                    losses.
         """
         featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
         assert len(featmap_sizes) == self.anchor_generator.num_levels
...
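Because every value in the returned dict is a list with one tensor per feature level, a consumer typically sums each list to a scalar before combining the terms; a minimal sketch:

```python
import torch

# One loss tensor per feature level, as the docstring above specifies.
losses = dict(loss_cls=[torch.tensor(0.5), torch.tensor(0.3)],
              loss_bbox=[torch.tensor(1.2), torch.tensor(0.8)],
              loss_dir=[torch.tensor(0.1), torch.tensor(0.2)])

total = sum(sum(per_level) for per_level in losses.values())
print(total)  # tensor(3.1000)
```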
@@ -62,7 +62,7 @@ class FreeAnchor3DHead(Anchor3DHead):
                 Ground truth boxes that should be ignored. Defaults to None.
         Returns:
-            dict: Loss items.
+            dict[str, torch.Tensor]: Loss items.
                 - positive_bag_loss (torch.Tensor): Loss of positive samples.
                 - negative_bag_loss (torch.Tensor): Loss of negative samples.
...
@@ -89,6 +89,29 @@ class PartA2RPNHead(Anchor3DHead):
             gt_labels,
             input_metas,
             gt_bboxes_ignore=None):
+        """Calculate losses.
+        Args:
+            cls_scores (list[torch.Tensor]): Multi-level class scores.
+            bbox_preds (list[torch.Tensor]): Multi-level bbox predictions.
+            dir_cls_preds (list[torch.Tensor]): Multi-level direction
+                class predictions.
+            gt_bboxes (list[:obj:`BaseInstance3DBoxes`]): Ground truth
+                boxes of each sample.
+            gt_labels (list[torch.Tensor]): Ground truth labels of
+                each sample.
+            input_metas (list[dict]): Meta information of the point clouds
+                and images.
+            gt_bboxes_ignore (None | list[torch.Tensor]): Specify
+                which bounding boxes to ignore.
+        Returns:
+            dict[str, list[torch.Tensor]]: Classification, bbox, and direction
+                losses of each level.
+                - loss_rpn_cls (list[torch.Tensor]): Classification losses.
+                - loss_rpn_bbox (list[torch.Tensor]): Box regression losses.
+                - loss_rpn_dir (list[torch.Tensor]): Direction classification
+                    losses.
+        """
         loss_dict = super().loss(cls_scores, bbox_preds, dir_cls_preds,
                                  gt_bboxes, gt_labels, input_metas,
                                  gt_bboxes_ignore)
...
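The body delegates to `super().loss`, and the documented `loss_rpn_*` keys suggest the parent's outputs are re-keyed afterwards; a hedged sketch of that renaming step, assuming the parent dict uses `loss_cls`/`loss_bbox`/`loss_dir`:

```python
def rename_rpn_losses(loss_dict):
    """Re-key Anchor3DHead losses with an rpn prefix (assumed key names)."""
    mapping = {'loss_cls': 'loss_rpn_cls',
               'loss_bbox': 'loss_rpn_bbox',
               'loss_dir': 'loss_rpn_dir'}
    return {mapping.get(key, key): value for key, value in loss_dict.items()}

print(rename_rpn_losses({'loss_cls': [0.4], 'loss_bbox': [1.0]}))
# {'loss_rpn_cls': [0.4], 'loss_rpn_bbox': [1.0]}
```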
@@ -6,6 +6,7 @@ from mmdet.core import images_to_levels, multi_apply
 class AnchorTrainMixin(object):
+    """Mixin class for target assigning of dense heads."""
     def anchor_target_3d(self,
                          anchor_list,
...
@@ -107,6 +107,7 @@ class VoteHead(nn.Module):
             nn.Conv1d(prev_channel, conv_out_channel, 1))
     def init_weights(self):
+        """Initialize weights of VoteHead."""
         pass
     def forward(self, feat_dict, sample_mod):
...
@@ -62,7 +62,7 @@ class Base3DDetector(BaseDetector):
         """Results visualization.
         Args:
-            data (dict): Input points and info.
+            data (dict): Input points and the information of the sample.
             result (dict): Prediction results.
             out_dir (str): Output directory of visualization result.
         """
...
@@ -285,6 +285,21 @@ class MVXTwoStageDetector(Base3DDetector):
             gt_labels_3d,
             img_metas,
             gt_bboxes_ignore=None):
+        """Forward function for point cloud branch.
+        Args:
+            pts_feats (list[torch.Tensor]): Features of point cloud branch.
+            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
+                boxes for each sample.
+            gt_labels_3d (list[torch.Tensor]): Ground truth labels for
+                boxes of each sample.
+            img_metas (list[dict]): Meta information of samples.
+            gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
+                boxes to be ignored. Defaults to None.
+        Returns:
+            dict: Losses of each branch.
+        """
         outs = self.pts_bbox_head(pts_feats)
         loss_inputs = outs + (gt_bboxes_3d, gt_labels_3d, img_metas)
         losses = self.pts_bbox_head.loss(
@@ -299,6 +314,25 @@ class MVXTwoStageDetector(Base3DDetector):
             gt_bboxes_ignore=None,
             proposals=None,
             **kwargs):
+        """Forward function for image branch.
+        This function works similarly to the forward function of Faster R-CNN.
+        Args:
+            x (list[torch.Tensor]): Image features of shape (B, C, H, W)
+                of multiple levels.
+            img_metas (list[dict]): Meta information of images.
+            gt_bboxes (list[torch.Tensor]): Ground truth boxes of each image
+                sample.
+            gt_labels (list[torch.Tensor]): Ground truth labels of boxes.
+            gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
+                boxes to be ignored. Defaults to None.
+            proposals (list[torch.Tensor], optional): Proposals of each sample.
+                Defaults to None.
+        Returns:
+            dict: Losses of each branch.
+        """
         losses = dict()
         # RPN forward and loss
         if self.with_img_rpn:
@@ -338,12 +372,14 @@ class MVXTwoStageDetector(Base3DDetector):
             x, proposal_list, img_metas, rescale=rescale)
     def simple_test_rpn(self, x, img_metas, rpn_test_cfg):
+        """RPN test function."""
         rpn_outs = self.img_rpn_head(x)
         proposal_inputs = rpn_outs + (img_metas, rpn_test_cfg)
         proposal_list = self.img_rpn_head.get_bboxes(*proposal_inputs)
         return proposal_list
     def simple_test_pts(self, x, img_metas, rescale=False):
+        """Test function of point cloud branch."""
         outs = self.pts_bbox_head(x)
         bbox_list = self.pts_bbox_head.get_bboxes(
             *outs, img_metas, rescale=rescale)
@@ -354,6 +390,7 @@ class MVXTwoStageDetector(Base3DDetector):
         return bbox_results[0]
     def simple_test(self, points, img_metas, img=None, rescale=False):
+        """Test function without augmentation."""
         img_feats, pts_feats = self.extract_feat(
             points, img=img, img_metas=img_metas)
@@ -369,6 +406,7 @@ class MVXTwoStageDetector(Base3DDetector):
         return bbox_list
     def aug_test(self, points, img_metas, imgs=None, rescale=False):
+        """Test function with augmentation."""
         img_feats, pts_feats = self.extract_feats(points, img_metas, imgs)
         bbox_list = dict()
@@ -378,6 +416,7 @@ class MVXTwoStageDetector(Base3DDetector):
         return bbox_list
     def extract_feats(self, points, img_metas, imgs=None):
+        """Extract point and image features of multiple samples."""
         if imgs is None:
             imgs = [None] * len(img_metas)
         img_feats, pts_feats = multi_apply(self.extract_feat, points, imgs,
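`multi_apply` from `mmdet.core` calls a function once per sample and transposes the per-sample result tuples into tuples of lists, which is how one call produces both `img_feats` and `pts_feats`. A tiny stand-in for intuition:

```python
from functools import partial

def multi_apply_sketch(func, *args, **kwargs):
    """Apply func to each sample, then regroup results field-by-field."""
    pfunc = partial(func, **kwargs) if kwargs else func
    map_results = map(pfunc, *args)
    return tuple(map(list, zip(*map_results)))

def extract_feat_stub(points, img):
    return f'img_feat({img})', f'pts_feat({points})'

img_feats, pts_feats = multi_apply_sketch(
    extract_feat_stub, ['pts0', 'pts1'], ['img0', 'img1'])
print(img_feats)  # ['img_feat(img0)', 'img_feat(img1)']
print(pts_feats)  # ['pts_feat(pts0)', 'pts_feat(pts1)']
```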
@@ -385,6 +424,7 @@ class MVXTwoStageDetector(Base3DDetector):
         return img_feats, pts_feats
     def aug_test_pts(self, feats, img_metas, rescale=False):
+        """Test function of point cloud branch with augmentation."""
         # only support aug_test for one sample
         aug_bboxes = []
         for x, img_meta in zip(feats, img_metas):
@@ -406,7 +446,7 @@ class MVXTwoStageDetector(Base3DDetector):
         """Results visualization.
         Args:
-            data (dict): Input points and info.
+            data (dict): Input points and the information of the sample.
             result (dict): Prediction results.
             out_dir (str): Output directory of visualization result.
         """
...
@@ -9,7 +9,7 @@ from .two_stage import TwoStage3DDetector
 @DETECTORS.register_module()
 class PartA2(TwoStage3DDetector):
-    """Part-A2 detector.
+    r"""Part-A2 detector.
     Please refer to the `paper <https://arxiv.org/abs/1907.03670>`_
     """
@@ -39,6 +39,7 @@ class PartA2(TwoStage3DDetector):
         self.middle_encoder = builder.build_middle_encoder(middle_encoder)
     def extract_feat(self, points, img_metas):
+        """Extract features from points."""
         voxel_dict = self.voxelize(points)
         voxel_features = self.voxel_encoder(voxel_dict['voxels'],
                                             voxel_dict['num_points'],
@@ -54,6 +55,7 @@ class PartA2(TwoStage3DDetector):
     @torch.no_grad()
     def voxelize(self, points):
+        """Apply hard voxelization to points."""
         voxels, coors, num_points, voxel_centers = [], [], [], []
         for res in points:
             res_voxels, res_coors, res_num_points = self.voxel_layer(res)
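Hard voxelization runs per sample, so the per-sample results have to be stitched back into a batch; the usual mmdet3d pattern pads each sample's voxel coordinates with its batch index before concatenation. A sketch of just that batching step:

```python
import torch
import torch.nn.functional as F

def batch_coors(coors_list):
    """Prepend the sample index to per-sample (M_i, 3) voxel coordinates."""
    batched = []
    for i, coor in enumerate(coors_list):
        # pad one column on the left of the last dim, filled with index i,
        # turning (M_i, 3) z/y/x coords into (M_i, 4) batch/z/y/x coords
        batched.append(F.pad(coor, (1, 0), mode='constant', value=i))
    return torch.cat(batched, dim=0)

coors = batch_coors([torch.zeros(5, 3, dtype=torch.long),
                     torch.ones(4, 3, dtype=torch.long)])
print(coors.shape)  # torch.Size([9, 4])
print(coors[5])     # tensor([1, 1, 1, 1]) -> batch index 1 prepended
```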
@@ -89,6 +91,21 @@ class PartA2(TwoStage3DDetector):
             gt_labels_3d,
             gt_bboxes_ignore=None,
             proposals=None):
+        """Training forward function.
+        Args:
+            points (list[torch.Tensor]): Point cloud of each sample.
+            img_metas (list[dict]): Meta information of each sample.
+            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
+                boxes for each sample.
+            gt_labels_3d (list[torch.Tensor]): Ground truth labels for
+                boxes of each sample.
+            gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
+                boxes to be ignored. Defaults to None.
+        Returns:
+            dict: Losses of each branch.
+        """
         feats_dict, voxels_dict = self.extract_feat(points, img_metas)
         losses = dict()
@@ -117,6 +134,7 @@ class PartA2(TwoStage3DDetector):
         return losses
     def simple_test(self, points, img_metas, proposals=None, rescale=False):
+        """Test function without augmentation."""
         feats_dict, voxels_dict = self.extract_feat(points, img_metas)
         if self.with_rpn:
...