Commit 80b39bd0 authored by zhangwenwei

Reformat docstrings in code

parent 64d7fbc2
import numpy as np
import torch
-import torch.nn as nn
-import torch.nn.functional as F
from mmcv.cnn import ConvModule
+from torch import nn as nn
+from torch.nn import functional as F

from mmdet3d.core.post_processing import aligned_3d_nms
from mmdet3d.models.builder import build_loss
......
-import os.path as osp
+from os import path as osp

from mmdet3d.core import Box3DMode, show_result
from mmdet.models.detectors import BaseDetector


class Base3DDetector(BaseDetector):
-    """Base class for detectors"""
+    """Base class for detectors."""

    def forward_test(self, points, img_metas, img=None, **kwargs):
        """
@@ -42,14 +42,15 @@ class Base3DDetector(BaseDetector):
        return self.aug_test(points, img_metas, img, **kwargs)

    def forward(self, return_loss=True, **kwargs):
-        """
-        Calls either forward_train or forward_test depending on whether
-        return_loss=True. Note this setting will change the expected inputs.
-        When `return_loss=True`, img and img_metas are single-nested (i.e.
-        torch.Tensor and list[dict]), and when `return_loss=False`, img and
-        img_metas should be double nested
-        (i.e. list[torch.Tensor], list[list[dict]]), with the outer list
-        indicating test time augmentations.
+        """Calls either forward_train or forward_test depending on whether
+        return_loss=True.
+
+        Note this setting will change the expected inputs. When
+        `return_loss=True`, img and img_metas are single-nested (i.e.
+        torch.Tensor and list[dict]), and when `return_loss=False`, img and
+        img_metas should be double nested (i.e. list[torch.Tensor],
+        list[list[dict]]), with the outer list indicating test time
+        augmentations.
        """
        if return_loss:
            return self.forward_train(**kwargs)
......
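Note: the dispatch described in the docstring above follows the standard mmdetection pattern. A minimal runnable sketch of it; the `MyDetector` class and its stub methods are hypothetical stand-ins, not part of this commit:

```python
import torch
from torch import nn as nn


class MyDetector(nn.Module):
    """Hypothetical stand-in showing the return_loss dispatch."""

    def forward(self, return_loss=True, **kwargs):
        if return_loss:
            # Training: img is a Tensor, img_metas a list[dict].
            return self.forward_train(**kwargs)
        # Testing: img is a list[Tensor], img_metas a list[list[dict]],
        # one entry per test-time augmentation.
        return self.forward_test(**kwargs)

    def forward_train(self, img=None, img_metas=None):
        return dict(loss=torch.tensor(0.))

    def forward_test(self, img=None, img_metas=None):
        return [dict()]
```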
import torch
-import torch.nn.functional as F
+from torch.nn import functional as F

from mmdet.models import DETECTORS
from .voxelnet import VoxelNet
@@ -7,8 +7,8 @@ from .voxelnet import VoxelNet
@DETECTORS.register_module()
class DynamicVoxelNet(VoxelNet):
-    """VoxelNet using `dynamic voxelization
-    <https://arxiv.org/abs/1910.06528>`_."""
+    r"""VoxelNet using `dynamic voxelization <https://arxiv.org/abs/1910.06528>`_.
+    """

    def __init__(self,
                 voxel_layer,
@@ -33,7 +33,7 @@ class DynamicVoxelNet(VoxelNet):
        )

    def extract_feat(self, points, img_metas):
-        """Extract features from points"""
+        """Extract features from points."""
        voxels, coors = self.voxelize(points)
        voxel_features, feature_coors = self.voxel_encoder(voxels, coors)
        batch_size = coors[-1, 0].item() + 1
@@ -45,7 +45,14 @@ class DynamicVoxelNet(VoxelNet):
    @torch.no_grad()
    def voxelize(self, points):
-        """Apply dynamic voxelization to points"""
+        """Apply dynamic voxelization to points.
+
+        Args:
+            points (list[torch.Tensor]): Points of each sample.
+
+        Returns:
+            tuple[torch.Tensor]: Concatenated points and coordinates.
+        """
        coors = []
        # dynamic voxelization only provide a coors mapping
        for res in points:
......
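Note: the truncated `voxelize` body above collects per-sample coordinates and then tags them with a batch index before concatenation. A plausible reconstruction of that batching step, assuming the mmdet3d convention of prepending the sample index as a new first coordinate column (sketch only; the helper name is illustrative):

```python
import torch
from torch.nn import functional as F


def batch_points_and_coors(points, coors):
    """points, coors: per-sample lists of (N_i, 3+) and (N_i, 3) Tensors."""
    coors_batch = []
    for i, coor in enumerate(coors):
        # Prepend the batch index, turning (N_i, 3) into (N_i, 4).
        coor_pad = F.pad(coor, (1, 0), mode='constant', value=i)
        coors_batch.append(coor_pad)
    return torch.cat(points, dim=0), torch.cat(coors_batch, dim=0)
```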
import torch
-import torch.nn.functional as F
+from torch.nn import functional as F

from mmdet.models import DETECTORS
from .mvx_two_stage import MVXTwoStageDetector
@@ -7,6 +7,7 @@ from .mvx_two_stage import MVXTwoStageDetector
@DETECTORS.register_module()
class MVXFasterRCNN(MVXTwoStageDetector):
+    """Multi-modality VoxelNet using Faster R-CNN."""

    def __init__(self, **kwargs):
        super(MVXFasterRCNN, self).__init__(**kwargs)
@@ -14,12 +15,21 @@ class MVXFasterRCNN(MVXTwoStageDetector):
@DETECTORS.register_module()
class DynamicMVXFasterRCNN(MVXTwoStageDetector):
+    """Multi-modality VoxelNet using Faster R-CNN and dynamic voxelization."""

    def __init__(self, **kwargs):
        super(DynamicMVXFasterRCNN, self).__init__(**kwargs)

    @torch.no_grad()
    def voxelize(self, points):
+        """Apply dynamic voxelization to points.
+
+        Args:
+            points (list[torch.Tensor]): Points of each sample.
+
+        Returns:
+            tuple[torch.Tensor]: Concatenated points and coordinates.
+        """
        coors = []
        # dynamic voxelization only provide a coors mapping
        for res in points:
@@ -34,6 +44,7 @@ class DynamicMVXFasterRCNN(MVXTwoStageDetector):
        return points, coors_batch

    def extract_pts_feat(self, points, img_feats, img_metas):
+        """Extract point features."""
        if not self.with_pts_bbox:
            return None
        voxels, coors = self.voxelize(points)
......
-import os.path as osp
import torch
-import torch.nn as nn
-import torch.nn.functional as F
+from os import path as osp
+from torch import nn as nn
+from torch.nn import functional as F

from mmdet3d.core import (Box3DMode, bbox3d2result, merge_aug_bboxes_3d,
                          show_result)
@@ -15,6 +14,7 @@ from .base import Base3DDetector
@DETECTORS.register_module()
class MVXTwoStageDetector(Base3DDetector):
+    """Base class of Multi-modality VoxelNet."""

    def __init__(self,
                 pts_voxel_layer=None,
@@ -69,6 +69,7 @@ class MVXTwoStageDetector(Base3DDetector):
        self.init_weights(pretrained=pretrained)

    def init_weights(self, pretrained=None):
+        """Initialize model weights."""
        super(MVXTwoStageDetector, self).init_weights(pretrained)
        if pretrained is None:
            img_pretrained = None
@@ -99,59 +100,72 @@ class MVXTwoStageDetector(Base3DDetector):
    @property
    def with_img_shared_head(self):
+        """bool: Whether the detector has a shared head in image branch."""
        return hasattr(self,
                       'img_shared_head') and self.img_shared_head is not None

    @property
    def with_pts_bbox(self):
+        """bool: Whether the detector has a 3D box head."""
        return hasattr(self,
                       'pts_bbox_head') and self.pts_bbox_head is not None

    @property
    def with_img_bbox(self):
+        """bool: Whether the detector has a 2D image box head."""
        return hasattr(self,
                       'img_bbox_head') and self.img_bbox_head is not None

    @property
    def with_img_backbone(self):
+        """bool: Whether the detector has a 2D image backbone."""
        return hasattr(self, 'img_backbone') and self.img_backbone is not None

    @property
    def with_pts_backbone(self):
+        """bool: Whether the detector has a 3D backbone."""
        return hasattr(self, 'pts_backbone') and self.pts_backbone is not None

    @property
    def with_fusion(self):
+        """bool: Whether the detector has a fusion layer."""
        return hasattr(self,
                       'pts_fusion_layer') and self.fusion_layer is not None

    @property
    def with_img_neck(self):
+        """bool: Whether the detector has a neck in image branch."""
        return hasattr(self, 'img_neck') and self.img_neck is not None

    @property
    def with_pts_neck(self):
+        """bool: Whether the detector has a neck in 3D detector branch."""
        return hasattr(self, 'pts_neck') and self.pts_neck is not None

    @property
    def with_img_rpn(self):
+        """bool: Whether the detector has a 2D RPN in image detector branch."""
        return hasattr(self, 'img_rpn_head') and self.img_rpn_head is not None

    @property
    def with_img_roi_head(self):
+        """bool: Whether the detector has a RoI Head in image branch."""
        return hasattr(self, 'img_roi_head') and self.img_roi_head is not None

    @property
    def with_voxel_encoder(self):
+        """bool: Whether the detector has a voxel encoder."""
        return hasattr(self,
                       'voxel_encoder') and self.voxel_encoder is not None

    @property
    def with_middle_encoder(self):
+        """bool: Whether the detector has a middle encoder."""
        return hasattr(self,
                       'middle_encoder') and self.middle_encoder is not None

    def extract_img_feat(self, img, img_metas):
+        """Extract features of images."""
        if self.with_img_backbone:
            if img.dim() == 5 and img.size(0) == 1:
                img.squeeze_()
@@ -166,6 +180,7 @@ class MVXTwoStageDetector(Base3DDetector):
        return img_feats

    def extract_pts_feat(self, pts, img_feats, img_metas):
+        """Extract features of points."""
        if not self.with_pts_bbox:
            return None
        voxels, num_points, coors = self.voxelize(pts)
@@ -179,12 +194,22 @@ class MVXTwoStageDetector(Base3DDetector):
        return x

    def extract_feat(self, points, img, img_metas):
+        """Extract features from images and points."""
        img_feats = self.extract_img_feat(img, img_metas)
        pts_feats = self.extract_pts_feat(points, img_feats, img_metas)
        return (img_feats, pts_feats)
    @torch.no_grad()
    def voxelize(self, points):
+        """Apply hard voxelization to points.
+
+        Args:
+            points (list[torch.Tensor]): Points of each sample.
+
+        Returns:
+            tuple[torch.Tensor]: Concatenated points, number of points
+                per voxel, and coordinates.
+        """
        voxels, coors, num_points = [], [], []
        for res in points:
            res_voxels, res_coors, res_num_points = self.pts_voxel_layer(res)
@@ -208,8 +233,33 @@ class MVXTwoStageDetector(Base3DDetector):
                      gt_labels=None,
                      gt_bboxes=None,
                      img=None,
-                      bboxes=None,
+                      proposals=None,
                      gt_bboxes_ignore=None):
+        """Forward training function.
+
+        Args:
+            points (list[torch.Tensor], optional): Points of each sample.
+                Defaults to None.
+            img_metas (list[dict], optional): Meta information of each sample.
+                Defaults to None.
+            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional):
+                Ground truth 3D boxes. Defaults to None.
+            gt_labels_3d (list[torch.Tensor], optional): Ground truth labels
+                of 3D boxes. Defaults to None.
+            gt_labels (list[torch.Tensor], optional): Ground truth labels
+                of 2D boxes in images. Defaults to None.
+            gt_bboxes (list[torch.Tensor], optional): Ground truth 2D boxes in
+                images. Defaults to None.
+            img (torch.Tensor, optional): Images of each sample with shape
+                (N, C, H, W). Defaults to None.
+            proposals (list[torch.Tensor], optional): Predicted proposals
+                used for training Fast RCNN. Defaults to None.
+            gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
+                2D boxes in images to be ignored. Defaults to None.
+
+        Returns:
+            dict: Losses of different branches.
+        """
        img_feats, pts_feats = self.extract_feat(
            points, img=img, img_metas=img_metas)
        losses = dict()
@@ -225,8 +275,7 @@ class MVXTwoStageDetector(Base3DDetector):
            gt_bboxes=gt_bboxes,
            gt_labels=gt_labels,
            gt_bboxes_ignore=gt_bboxes_ignore,
-            bboxes=bboxes,
-        )
+            proposals=proposals)
        losses.update(losses_img)
        return losses
......
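Note: the dozen `with_*` properties added above all instantiate one guard pattern: report an optional sub-module as present only if the attribute exists and was actually built. A generic sketch of the pattern; the class and attribute names here are hypothetical:

```python
from torch import nn as nn


class OptionalBranches(nn.Module):
    """Hypothetical stand-in for the optional sub-module guards."""

    def __init__(self, pts_neck=None):
        super().__init__()
        self.pts_neck = pts_neck

    @property
    def with_pts_neck(self):
        """bool: Whether the model has a neck in the 3D branch."""
        # Equivalent to the hasattr(...) and ... is not None checks above.
        return getattr(self, 'pts_neck', None) is not None
```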
import torch
-import torch.nn.functional as F
+from torch.nn import functional as F

from mmdet3d.ops import Voxelization
from mmdet.models import DETECTORS
@@ -9,7 +9,7 @@ from .two_stage import TwoStage3DDetector
@DETECTORS.register_module()
class PartA2(TwoStage3DDetector):
-    """Part-A2 detector
+    """Part-A2 detector.

    Please refer to the `paper <https://arxiv.org/abs/1907.03670>`_
    """
......
-import torch.nn as nn
+from torch import nn as nn

from mmdet.models import DETECTORS, build_backbone, build_head, build_neck
from .base import Base3DDetector
@@ -6,7 +6,7 @@ from .base import Base3DDetector
@DETECTORS.register_module()
class SingleStage3DDetector(Base3DDetector):
-    """SingleStage3DDetector
+    """SingleStage3DDetector.

    This class serves as a base class for single-stage 3D detectors.
......
@@ -4,11 +4,11 @@ from .base import Base3DDetector
@DETECTORS.register_module()
class TwoStage3DDetector(Base3DDetector, TwoStageDetector):
-    """Base class of two-stage 3D detector
+    """Base class of two-stage 3D detector.

    It inherits original ``:class:TwoStageDetector`` and
-    ``:class:Base3DDetector``. This class could serve as a base class for
-    all two-stage 3D detectors.
+    ``:class:Base3DDetector``. This class could serve as a base class for all
+    two-stage 3D detectors.
    """

    def __init__(self, **kwargs):
......
import torch
-import torch.nn.functional as F
+from torch.nn import functional as F

from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d
from mmdet3d.ops import Voxelization
@@ -34,7 +34,7 @@ class VoxelNet(SingleStage3DDetector):
        self.middle_encoder = builder.build_middle_encoder(middle_encoder)

    def extract_feat(self, points, img_metas):
-        """Extract features from points"""
+        """Extract features from points."""
        voxels, num_points, coors = self.voxelize(points)
        voxel_features = self.voxel_encoder(voxels, num_points, coors)
        batch_size = coors[-1, 0].item() + 1
@@ -46,7 +46,7 @@ class VoxelNet(SingleStage3DDetector):
    @torch.no_grad()
    def voxelize(self, points):
-        """Apply hard voxelization to points"""
+        """Apply hard voxelization to points."""
        voxels, coors, num_points = [], [], []
        for res in points:
            res_voxels, res_coors, res_num_points = self.voxel_layer(res)
......
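Note: unlike the dynamic variant earlier, the hard `voxelize` above gets fixed-size voxel buffers and per-voxel point counts back from the voxel layer, so voxels and counts concatenate directly and only the coordinates need a batch index. A hedged sketch of that batching, with an illustrative helper name:

```python
import torch
from torch.nn import functional as F


def batch_hard_voxels(per_sample):
    """per_sample: list of (voxels, coors, num_points) tuples, one per sample."""
    voxels = torch.cat([v for v, _, _ in per_sample], dim=0)
    num_points = torch.cat([n for _, _, n in per_sample], dim=0)
    coors = torch.cat([
        # Prepend the sample index as the first coordinate column.
        F.pad(c, (1, 0), mode='constant', value=i)
        for i, (_, c, _) in enumerate(per_sample)
    ], dim=0)
    return voxels, num_points, coors
```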
import torch
-import torch.nn as nn
-import torch.nn.functional as F
from mmcv.cnn import ConvModule, xavier_init
+from torch import nn as nn
+from torch.nn import functional as F

from ..registry import FUSION_LAYERS
@@ -23,7 +23,7 @@ def point_sample(
        padding_mode='zeros',
        align_corners=True,
):
-    """Obtain image features using points
+    """Obtain image features using points.

    Args:
        img_features (Tensor): 1xCxHxW image features
@@ -113,7 +113,7 @@ def point_sample(
@FUSION_LAYERS.register_module()
class PointFusion(nn.Module):
-    """Fuse image features from multi-scale features
+    """Fuse image features from multi-scale features.

    Args:
        img_channels (list[int] | int): Channels of image features.
@@ -225,7 +225,7 @@ class PointFusion(nn.Module):
            xavier_init(m, distribution='uniform')

    def forward(self, img_feats, pts, pts_feats, img_metas):
-        """Forward function
+        """Forward function.

        Args:
            img_feats (list[Tensor]): img features
......
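Note: the core of `point_sample` is bilinear sampling of image features at projected point locations via `F.grid_sample`. A minimal sketch that assumes the points have already been projected and normalized to [-1, 1]; the projection itself is omitted and the function name is illustrative:

```python
import torch
from torch.nn import functional as F


def sample_img_feats(img_features, norm_xy, align_corners=True):
    """img_features: (1, C, H, W); norm_xy: (N, 2) in [-1, 1]."""
    # grid_sample expects a (1, 1, N, 2) sampling grid here.
    grid = norm_xy.view(1, 1, -1, 2)
    sampled = F.grid_sample(
        img_features,
        grid,
        mode='bilinear',
        padding_mode='zeros',
        align_corners=align_corners)
    # (1, C, 1, N) -> (N, C) point-wise image features.
    return sampled.squeeze(0).squeeze(1).t()
```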
import torch
-import torch.nn as nn
+from torch import nn as nn
from torch.nn.functional import l1_loss, mse_loss, smooth_l1_loss

from mmdet.models.builder import LOSSES
......
@@ -23,7 +23,7 @@ class PointPillarsScatter(nn.Module):
        self.in_channels = in_channels

    def forward(self, voxel_features, coors, batch_size=None):
-        """Forward function to scatter features"""
+        """Forward function to scatter features."""
        # TODO: rewrite the function in a batch manner
        # no need to deal with different batch cases
        if batch_size is not None:
@@ -32,7 +32,7 @@ class PointPillarsScatter(nn.Module):
        return self.forward_single(voxel_features, coors)

    def forward_single(self, voxel_features, coors):
-        """Scatter features of single sample
+        """Scatter features of single sample.

        Args:
            voxel_features (torch.Tensor): Voxel features in shape (N, M, C).
@@ -56,7 +56,7 @@ class PointPillarsScatter(nn.Module):
        return [canvas]

    def forward_batch(self, voxel_features, coors, batch_size):
-        """Scatter features of batched samples
+        """Scatter features of batched samples.

        Args:
            voxel_features (torch.Tensor): Voxel features in shape (N, M, C).
......
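Note: `forward_single` above scatters pillar features onto a dense bird's-eye-view canvas. A sketch of that step under the usual (batch, z, y, x) coordinate layout; the function name, argument order, and the (P, C) feature shape are illustrative assumptions:

```python
import torch


def scatter_single(voxel_features, coors, in_channels, ny, nx):
    """voxel_features: (P, C); coors: (P, 4) ordered (batch, z, y, x)."""
    canvas = voxel_features.new_zeros(in_channels, nx * ny)
    # Flatten each pillar's (y, x) cell into a single column index.
    indices = (coors[:, 2] * nx + coors[:, 3]).long()
    # Place each pillar's feature vector at its canvas column.
    canvas[:, indices] = voxel_features.t()
    return canvas.view(in_channels, ny, nx)
```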
-import torch.nn as nn
+from torch import nn as nn

-import mmdet3d.ops.spconv as spconv
from mmdet3d.ops import make_sparse_convmodule
+from mmdet3d.ops import spconv as spconv
from ..registry import MIDDLE_ENCODERS


@MIDDLE_ENCODERS.register_module()
class SparseEncoder(nn.Module):
-    """Sparse encoder for SECOND
+    """Sparse encoder for SECOND.

    See https://arxiv.org/abs/1907.03670 for more details.
@@ -81,7 +81,7 @@ class SparseEncoder(nn.Module):
                conv_type='SparseConv3d')

    def forward(self, voxel_features, coors, batch_size):
-        """Forward of SparseEncoder
+        """Forward of SparseEncoder.

        Args:
            voxel_features (torch.float32): shape [N, C]
@@ -113,7 +113,7 @@ class SparseEncoder(nn.Module):
        return spatial_features

    def make_encoder_layers(self, make_block, norm_cfg, in_channels):
-        """Make encoder layers using sparse convs
+        """Make encoder layers using sparse convs.

        Args:
            make_block (method): a bound function to build blocks
......
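Note: `make_encoder_layers` (here and in the SparseUNet below) folds a `make_block` callable over a channel specification. A generic, spconv-free sketch of the same shape, with a dense `nn.Conv3d` standing in for the sparse conv module; the channel numbers are illustrative:

```python
from torch import nn as nn


def make_encoder_layers(make_block, channels):
    """channels: e.g. [16, 32, 64]; returns the stacked encoder."""
    layers = []
    in_channels = channels[0]
    for out_channels in channels[1:]:
        # Each stage is produced by the injected block builder.
        layers.append(make_block(in_channels, out_channels))
        in_channels = out_channels
    return nn.Sequential(*layers)


# Usage with a dense stand-in block (stride-2 downsampling per stage):
encoder = make_encoder_layers(
    lambda cin, cout: nn.Conv3d(cin, cout, 3, stride=2, padding=1),
    [16, 32, 64])
```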
import torch
-import torch.nn as nn
+from torch import nn as nn

-import mmdet3d.ops.spconv as spconv
from mmdet3d.ops import SparseBasicBlock, make_sparse_convmodule
+from mmdet3d.ops import spconv as spconv
from ..registry import MIDDLE_ENCODERS


@MIDDLE_ENCODERS.register_module()
class SparseUNet(nn.Module):
-    """SparseUNet for PartA^2
+    """SparseUNet for PartA^2.

    See https://arxiv.org/abs/1907.03670 for more details.
@@ -92,7 +92,7 @@ class SparseUNet(nn.Module):
                conv_type='SparseConv3d')

    def forward(self, voxel_features, coors, batch_size):
-        """Forward of SparseUNet
+        """Forward of SparseUNet.

        Args:
            voxel_features (torch.float32): shape [N, C]
@@ -184,7 +184,7 @@ class SparseUNet(nn.Module):
        return x

    def make_encoder_layers(self, make_block, norm_cfg, in_channels):
-        """Make encoder layers using sparse convs
+        """Make encoder layers using sparse convs.

        Args:
            make_block (method): a bound function to build blocks
@@ -230,7 +230,7 @@ class SparseUNet(nn.Module):
        return out_channels

    def make_decoder_layers(self, make_block, norm_cfg, in_channels):
-        """Make decoder layers using sparse convs
+        """Make decoder layers using sparse convs.

        Args:
            make_block (method): a bound function to build blocks
......
import torch
-import torch.nn as nn
from mmcv.cnn import ConvModule
+from torch import nn as nn

from mmdet3d.models.builder import build_loss
......
import torch
-import torch.nn as nn
from mmcv.cnn import (build_norm_layer, build_upsample_layer, constant_init,
                      is_norm, kaiming_init)
+from torch import nn as nn

from mmdet.models import NECKS


@NECKS.register_module()
class SECONDFPN(nn.Module):
-    """FPN used in SECOND/PointPillars/PartA2/MVXNet
+    """FPN used in SECOND/PointPillars/PartA2/MVXNet.

    Args:
        in_channels (list[int]): Input channels of multi-scale feature maps
@@ -46,7 +46,7 @@ class SECONDFPN(nn.Module):
        self.deblocks = nn.ModuleList(deblocks)

    def init_weights(self):
-        """Initialize weights of FPN"""
+        """Initialize weights of FPN."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                kaiming_init(m)
@@ -54,7 +54,7 @@ class SECONDFPN(nn.Module):
                constant_init(m, 1)

    def forward(self, x):
-        """Forward function
+        """Forward function.

        Args:
            x (torch.Tensor): 4D Tensor in (N, C, H, W) shape.
......
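Note: SECONDFPN's `forward` (truncated above) applies one deblock per input scale and concatenates the upsampled maps along channels. A hedged sketch of that combine step; the helper name is illustrative:

```python
import torch


def fpn_combine(deblocks, feats):
    """deblocks: iterable of upsample modules; feats: list of (N, C, H, W)."""
    ups = [deblock(feats[i]) for i, deblock in enumerate(deblocks)]
    # Concatenate multi-scale maps along channels; pass through if single.
    return torch.cat(ups, dim=1) if len(ups) > 1 else ups[0]
```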
from abc import ABCMeta, abstractmethod

-import torch.nn as nn
+from torch import nn as nn


class Base3DRoIHead(nn.Module, metaclass=ABCMeta):
-    """Base class for 3d RoIHeads"""
+    """Base class for 3d RoIHeads."""

    def __init__(self,
                 bbox_head=None,
@@ -51,7 +50,7 @@ class Base3DRoIHead(nn.Module, metaclass=ABCMeta):
    @abstractmethod
    def init_assigner_sampler(self):
-        """Initialize assigner and sampler"""
+        """Initialize assigner and sampler."""
        pass

    @abstractmethod
@@ -63,7 +62,7 @@ class Base3DRoIHead(nn.Module, metaclass=ABCMeta):
                      gt_labels,
                      gt_bboxes_ignore=None,
                      **kwargs):
-        """Forward function during training
+        """Forward function during training.

        Args:
            x (dict): Contains features from the first stage.
......
import numpy as np
import torch
-import torch.nn as nn
from mmcv.cnn import ConvModule, normal_init, xavier_init
+from torch import nn as nn

-import mmdet3d.ops.spconv as spconv
from mmdet3d.core.bbox.structures import (LiDARInstance3DBoxes,
                                          rotation_3d_in_axis, xywhr2xyxyr)
from mmdet3d.models.builder import build_loss
from mmdet3d.ops import make_sparse_convmodule
+from mmdet3d.ops import spconv as spconv
from mmdet3d.ops.iou3d.iou3d_utils import nms_gpu, nms_normal_gpu
from mmdet.core import build_bbox_coder, multi_apply
from mmdet.models import HEADS
@@ -223,7 +223,7 @@ class PartA2BboxHead(nn.Module):
        self.init_weights()

    def init_weights(self):
-        """Initialize weights of the bbox head"""
+        """Initialize weights of the bbox head."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Conv1d)):
                xavier_init(m, distribution='uniform')
@@ -556,7 +556,7 @@ class PartA2BboxHead(nn.Module):
                      nms_thr,
                      input_meta,
                      use_rotate_nms=True):
-        """Multi-class NMS for box head
+        """Multi-class NMS for box head.

        Note:
            This function has large overlap with the `box3d_multiclass_nms`
......
import torch
-import torch.nn as nn
-import torch.nn.functional as F
+from torch import nn as nn
+from torch.nn import functional as F

from mmdet3d.core.bbox.structures import rotation_3d_in_axis
from mmdet3d.models.builder import build_loss
@@ -57,7 +57,6 @@ class PointwiseSemanticHead(nn.Module):
        Returns:
            dict: part features, segmentation and part predictions.
        """
-
        seg_preds = self.seg_cls_layer(x)  # (N, 1)
        part_preds = self.seg_reg_layer(x)  # (N, 3)
@@ -73,7 +72,8 @@ class PointwiseSemanticHead(nn.Module):
            seg_preds=seg_preds, part_preds=part_preds, part_feats=part_feats)

    def get_targets_single(self, voxel_centers, gt_bboxes_3d, gt_labels_3d):
-        """Generate segmentation and part prediction targets for a single sample
+        """Generate segmentation and part prediction targets for a single
+        sample.

        Args:
            voxel_centers (torch.Tensor): shape [voxel_num, 3],
@@ -120,7 +120,7 @@ class PointwiseSemanticHead(nn.Module):
        return seg_targets, part_targets

    def get_targets(self, voxels_dict, gt_bboxes_3d, gt_labels_3d):
-        """Generate segmentation and part prediction targets
+        """Generate segmentation and part prediction targets.

        Args:
            voxel_centers (torch.Tensor): shape [voxel_num, 3],
......
-import torch.nn.functional as F
+from torch.nn import functional as F

from mmdet3d.core import AssignResult
from mmdet3d.core.bbox import bbox3d2result, bbox3d2roi
@@ -10,7 +10,7 @@ from .base_3droi_head import Base3DRoIHead
@HEADS.register_module()
class PartAggregationROIHead(Base3DRoIHead):
-    """Part aggregation roi head for PartA2
+    """Part aggregation roi head for PartA2.

    Args:
        semantic_head (ConfigDict): Config of semantic head.
@@ -44,21 +44,21 @@ class PartAggregationROIHead(Base3DRoIHead):
        self.init_assigner_sampler()

    def init_weights(self, pretrained):
-        """Initialize weights, skip since ``PartAggregationROIHead``
-        does not need to initialize weights"""
+        """Initialize weights, skip since ``PartAggregationROIHead`` does not
+        need to initialize weights."""
        pass

    def init_mask_head(self):
-        """Initialize mask head, skip since ``PartAggregationROIHead``
-        does not have one."""
+        """Initialize mask head, skip since ``PartAggregationROIHead`` does not
+        have one."""
        pass

    def init_bbox_head(self, bbox_head):
-        """Initialize box head"""
+        """Initialize box head."""
        self.bbox_head = build_head(bbox_head)

    def init_assigner_sampler(self):
-        """Initialize assigner and sampler"""
+        """Initialize assigner and sampler."""
        self.bbox_assigner = None
        self.bbox_sampler = None
        if self.train_cfg:
@@ -78,7 +78,7 @@ class PartAggregationROIHead(Base3DRoIHead):
    def forward_train(self, feats_dict, voxels_dict, img_metas, proposal_list,
                      gt_bboxes_3d, gt_labels_3d):
-        """Training forward function of PartAggregationROIHead
+        """Training forward function of PartAggregationROIHead.

        Args:
            feats_dict (dict): Contains features from the first stage.
@@ -116,7 +116,7 @@ class PartAggregationROIHead(Base3DRoIHead):
    def simple_test(self, feats_dict, voxels_dict, img_metas, proposal_list,
                    **kwargs):
-        """Simple testing forward function of PartAggregationROIHead
+        """Simple testing forward function of PartAggregationROIHead.

        Note:
            This function assumes that the batch size is 1
@@ -216,7 +216,7 @@ class PartAggregationROIHead(Base3DRoIHead):
        return bbox_results

    def _assign_and_sample(self, proposal_list, gt_bboxes_3d, gt_labels_3d):
-        """Assign and sample proposals for training
+        """Assign and sample proposals for training.

        Args:
            proposal_list (list[dict]): Proposals produced by RPN.
@@ -290,7 +290,7 @@ class PartAggregationROIHead(Base3DRoIHead):
    def _semantic_forward_train(self, x, voxels_dict, gt_bboxes_3d,
                                gt_labels_3d):
-        """Train semantic head
+        """Train semantic head.

        Args:
            x (torch.Tensor): Point-wise semantic features for segmentation
......