Commit 80b39bd0 authored by zhangwenwei's avatar zhangwenwei
Browse files

Reformat docstrings in code

parent 64d7fbc2
import torch
import torch.nn as nn
from torch import nn as nn
from mmdet3d import ops
from mmdet.models.builder import ROI_EXTRACTORS
......@@ -7,7 +7,7 @@ from mmdet.models.builder import ROI_EXTRACTORS
@ROI_EXTRACTORS.register_module()
class Single3DRoIAwareExtractor(nn.Module):
"""Point-wise roi-aware Extractor
"""Point-wise roi-aware Extractor.
Extract Point-wise roi features.
......@@ -29,7 +29,7 @@ class Single3DRoIAwareExtractor(nn.Module):
return roi_layers
def forward(self, feats, coordinate, batch_inds, rois):
"""Extract point-wise roi features
"""Extract point-wise roi features.
Args:
feats (FloatTensor): point-wise features with
......
......@@ -83,7 +83,7 @@ class PillarFeatureNet(nn.Module):
self.point_cloud_range = point_cloud_range
def forward(self, features, num_points, coors):
"""Forward function
"""Forward function.
Args:
features (torch.Tensor): Point features or raw points in shape
......@@ -136,7 +136,7 @@ class PillarFeatureNet(nn.Module):
@VOXEL_ENCODERS.register_module()
class DynamicPillarFeatureNet(PillarFeatureNet):
"""Pillar Feature Net using dynamic voxelization
"""Pillar Feature Net using dynamic voxelization.
The network prepares the pillar features and performs forward pass
through PFNLayers. The main difference is that it is used for
......@@ -205,7 +205,7 @@ class DynamicPillarFeatureNet(PillarFeatureNet):
voxel_size, point_cloud_range, average_points=True)
def map_voxel_center_to_point(self, pts_coors, voxel_mean, voxel_coors):
"""Map the centers of voxels to its corresponding points
"""Map the centers of voxels to its corresponding points.
Args:
pts_coors (torch.Tensor): The coordinates of each points, shape
......@@ -244,7 +244,7 @@ class DynamicPillarFeatureNet(PillarFeatureNet):
return center_per_point
def forward(self, features, coors):
"""Forward function
"""Forward function.
Args:
features (torch.Tensor): Point features or raw points in shape
......
......@@ -28,7 +28,7 @@ def get_paddings_indicator(actual_num, max_num, axis=0):
class VFELayer(nn.Module):
""" Voxel Feature Encoder layer.
"""Voxel Feature Encoder layer.
The voxel encoder is composed of a series of these layers.
This module do not support average pooling and only support to use
......@@ -59,7 +59,7 @@ class VFELayer(nn.Module):
self.linear = nn.Linear(in_channels, out_channels, bias=False)
def forward(self, inputs):
"""Forward function
"""Forward function.
Args:
inputs (torch.Tensor): Voxels features of shape (N, M, C).
......@@ -100,7 +100,7 @@ class VFELayer(nn.Module):
class PFNLayer(nn.Module):
""" Pillar Feature Net Layer.
"""Pillar Feature Net Layer.
The Pillar Feature Net is composed of a series of these layers, but the
PointPillars paper results only used a single PFNLayer.
......@@ -136,7 +136,7 @@ class PFNLayer(nn.Module):
self.mode = mode
def forward(self, inputs, num_voxels=None, aligned_distance=None):
"""Forward function
"""Forward function.
Args:
inputs (torch.Tensor): Pillar/Voxel inputs with shape (N, M, C).
......
......@@ -10,7 +10,7 @@ from .utils import VFELayer, get_paddings_indicator
@VOXEL_ENCODERS.register_module()
class HardSimpleVFE(nn.Module):
"""Simple voxel feature encoder used in SECOND
"""Simple voxel feature encoder used in SECOND.
It simply averages the values of points in a voxel.
"""
......@@ -19,7 +19,7 @@ class HardSimpleVFE(nn.Module):
super(HardSimpleVFE, self).__init__()
def forward(self, features, num_points, coors):
"""Forward function
"""Forward function.
Args:
features (torch.Tensor): point features in shape
......@@ -39,7 +39,7 @@ class HardSimpleVFE(nn.Module):
@VOXEL_ENCODERS.register_module()
class DynamicSimpleVFE(nn.Module):
"""Simple dynamic voxel feature encoder used in DV-SECOND
"""Simple dynamic voxel feature encoder used in DV-SECOND.
It simply averages the values of points in a voxel.
But the number of points in a voxel is dynamic and varies.
......@@ -57,7 +57,7 @@ class DynamicSimpleVFE(nn.Module):
@torch.no_grad()
def forward(self, features, coors):
"""Forward function
"""Forward function.
Args:
features (torch.Tensor): point features in shape
......@@ -76,7 +76,7 @@ class DynamicSimpleVFE(nn.Module):
@VOXEL_ENCODERS.register_module()
class DynamicVFE(nn.Module):
"""Dynamic Voxel feature encoder used in DV-SECOND
"""Dynamic Voxel feature encoder used in DV-SECOND.
It encodes features of voxels and their points. It could also fuse
image feature into voxel features in a point-wise manner.
......@@ -211,7 +211,7 @@ class DynamicVFE(nn.Module):
points=None,
img_feats=None,
img_metas=None):
"""Forward functions
"""Forward functions.
Args:
features (torch.Tensor): Features of voxels, shape is NxC.
......@@ -274,7 +274,7 @@ class DynamicVFE(nn.Module):
@VOXEL_ENCODERS.register_module()
class HardVFE(nn.Module):
"""Voxel feature encoder used in DV-SECOND
"""Voxel feature encoder used in DV-SECOND.
It encodes features of voxels and their points. It could also fuse
image feature into voxel features in a point-wise manner.
......@@ -374,7 +374,7 @@ class HardVFE(nn.Module):
coors,
img_feats=None,
img_metas=None):
"""Forward functions
"""Forward functions.
Args:
features (torch.Tensor): Features of voxels, shape is MxNxC.
......
......@@ -5,7 +5,7 @@ from . import ball_query_ext
class BallQuery(Function):
"""Ball Query
"""Ball Query.
Find nearby points in spherical space.
"""
......
......@@ -7,8 +7,8 @@ from . import furthest_point_sample_ext
class FurthestPointSampling(Function):
"""Furthest Point Sampling.
Uses iterative furthest point sampling to select a set of
features whose corresponding points have the furthest distance.
Uses iterative furthest point sampling to select a set of features whose
corresponding points have the furthest distance.
"""
@staticmethod
......
......@@ -5,7 +5,7 @@ from . import gather_points_ext
class GatherPoints(Function):
"""Gather Points
"""Gather Points.
Gather points with given index.
"""
......
from typing import Tuple
import torch
import torch.nn as nn
from torch import nn as nn
from torch.autograd import Function
from typing import Tuple
from ..ball_query import ball_query
from . import group_points_ext
......@@ -49,7 +48,7 @@ class QueryAndGroup(nn.Module):
assert self.uniform_sample
def forward(self, points_xyz, center_xyz, features=None):
"""forward
"""forward.
Args:
points_xyz (Tensor): (B, N, 3) xyz coordinates of the features.
......
from typing import Tuple
import torch
from torch.autograd import Function
from typing import Tuple
from . import interpolate_ext
......@@ -11,7 +10,7 @@ class ThreeInterpolate(Function):
@staticmethod
def forward(ctx, features: torch.Tensor, indices: torch.Tensor,
weight: torch.Tensor) -> torch.Tensor:
"""Performs weighted linear interpolation on 3 features
"""Performs weighted linear interpolation on 3 features.
Args:
features (Tensor): (B, C, M) Features descriptors to be
......@@ -40,7 +39,7 @@ class ThreeInterpolate(Function):
def backward(
ctx, grad_out: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Backward of three interpolate
"""Backward of three interpolate.
Args:
grad_out (Tensor): (B, C, N) tensor with gradients of outputs
......
from typing import Tuple
import torch
from torch.autograd import Function
from typing import Tuple
from . import interpolate_ext
......@@ -11,7 +10,8 @@ class ThreeNN(Function):
@staticmethod
def forward(ctx, target: torch.Tensor,
source: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Find the top-3 nearest neighbors of the target set from the source set.
"""Find the top-3 nearest neighbors of the target set from the source
set.
Args:
target (Tensor): shape (B, N, 3), points set that needs to
......
import torch
import torch.distributed as dist
import torch.nn as nn
from mmcv.cnn import NORM_LAYERS
from torch import distributed as dist
from torch import nn as nn
from torch.autograd.function import Function
......@@ -25,7 +25,7 @@ class AllReduce(Function):
@NORM_LAYERS.register_module('naiveSyncBN1d')
class NaiveSyncBatchNorm1d(nn.BatchNorm1d):
"""Syncronized Batch Normalization for 3D Tensors
"""Syncronized Batch Normalization for 3D Tensors.
Note:
This implementation is modified from
......@@ -70,7 +70,7 @@ class NaiveSyncBatchNorm1d(nn.BatchNorm1d):
@NORM_LAYERS.register_module('naiveSyncBN2d')
class NaiveSyncBatchNorm2d(nn.BatchNorm2d):
"""Syncronized Batch Normalization for 4D Tensors
"""Syncronized Batch Normalization for 4D Tensors.
Note:
This implementation is modified from
......
from typing import List
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from torch import nn as nn
from typing import List
from mmdet3d.ops import three_interpolate, three_nn
......
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from torch import nn as nn
from torch.nn import functional as F
from typing import List
from mmdet3d.ops import (GroupAll, QueryAndGroup, furthest_point_sample,
gather_points)
class PointSAModuleMSG(nn.Module):
"""Point set abstraction module with multi-scale grouping used in Pointnets.
"""Point set abstraction module with multi-scale grouping used in
Pointnets.
Args:
num_point (int): Number of points.
......
import mmcv
import torch
import torch.nn as nn
from torch import nn as nn
from torch.autograd import Function
from . import roiaware_pool3d_ext
......@@ -24,7 +24,7 @@ class RoIAwarePool3d(nn.Module):
self.mode = pool_method_map[mode]
def forward(self, rois, pts, pts_feature):
"""RoIAwarePool3d module forward
"""RoIAwarePool3d module forward.
Args:
rois (torch.Tensor): [N, 7],in LiDAR coordinate,
......@@ -46,7 +46,7 @@ class RoIAwarePool3dFunction(Function):
@staticmethod
def forward(ctx, rois, pts, pts_feature, out_size, max_pts_per_voxel,
mode):
"""RoIAwarePool3d function forward
"""RoIAwarePool3d function forward.
Args:
rois (torch.Tensor): [N, 7], in LiDAR coordinate,
......@@ -89,7 +89,7 @@ class RoIAwarePool3dFunction(Function):
@staticmethod
def backward(ctx, grad_out):
"""RoIAwarePool3d function forward
"""RoIAwarePool3d function forward.
Args:
grad_out (torch.Tensor): [N, out_x, out_y, out_z, C]
......
......@@ -11,9 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np
import torch
from mmcv.cnn import CONV_LAYERS
......
......@@ -11,11 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from collections import OrderedDict
import torch
from collections import OrderedDict
from torch import nn
from .structure import SparseConvTensor
......@@ -46,9 +44,8 @@ def _mean_update(vals, m_vals, t):
class SparseModule(nn.Module):
""" place holder,
All module subclass from this will take sptensor in SparseSequential.
"""
"""place holder, All module subclass from this will take sptensor in
SparseSequential."""
pass
......@@ -140,7 +137,9 @@ class SparseSequential(SparseModule):
return input
def fused(self):
"""don't use this. no effect.
"""don't use this.
no effect.
"""
from .conv import SparseConvolution
mods = [v for k, v in self._modules.items()]
......@@ -189,16 +188,14 @@ class SparseSequential(SparseModule):
class ToDense(SparseModule):
"""convert SparseConvTensor to NCHW dense tensor.
"""
"""convert SparseConvTensor to NCHW dense tensor."""
def forward(self, x: SparseConvTensor):
return x.dense()
class RemoveGrid(SparseModule):
"""remove pre-allocated grid buffer.
"""
"""remove pre-allocated grid buffer."""
def forward(self, x: SparseConvTensor):
x.grid = None
......
......@@ -4,9 +4,9 @@ import torch
def scatter_nd(indices, updates, shape):
"""pytorch edition of tensorflow scatter_nd.
this function don't contain except handle code. so use this carefully
when indice repeats, don't support repeat add which is supported
in tensorflow.
this function don't contain except handle code. so use this carefully when
indice repeats, don't support repeat add which is supported in tensorflow.
"""
ret = torch.zeros(*shape, dtype=updates.dtype, device=updates.device)
ndim = indices.shape[-1]
......
......@@ -11,10 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import unittest
class TestCase(unittest.TestCase):
......@@ -26,6 +24,7 @@ class TestCase(unittest.TestCase):
def assertAllEqual(self, a, b):
"""Asserts that two numpy arrays have the same values.
Args:
a: the expected numpy ndarray or anything can be converted to one.
b: the actual numpy ndarray or anything can be converted to one.
......@@ -56,6 +55,7 @@ class TestCase(unittest.TestCase):
def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6):
"""Asserts that two numpy arrays, or dicts of same, have near values.
This does not support nested dicts.
Args:
a: The expected numpy ndarray (or anything can be converted to one), or
......
import os.path as osp
import subprocess
import sys
from collections import defaultdict
import cv2
import mmcv
import subprocess
import sys
import torch
import torchvision
from collections import defaultdict
from os import path as osp
import mmdet
import mmdet3d
......
......@@ -4,7 +4,7 @@ from mmdet.core import BitmapMasks, PolygonMasks
def _get_config_directory():
""" Find the predefined detector config directory """
"""Find the predefined detector config directory."""
try:
# Assume we are running in the source mmdetection repo
repo_dpath = dirname(dirname(__file__))
......@@ -19,10 +19,10 @@ def _get_config_directory():
def test_config_build_detector():
"""
Test that all detection models defined in the configs can be initialized.
"""
"""Test that all detection models defined in the configs can be
initialized."""
from mmcv import Config
from mmdet3d.models import build_detector
config_dpath = _get_config_directory()
......@@ -74,10 +74,10 @@ def test_config_build_detector():
def test_config_build_pipeline():
"""
Test that all detection models defined in the configs can be initialized.
"""
"""Test that all detection models defined in the configs can be
initialized."""
from mmcv import Config
from mmdet3d.datasets.pipelines import Compose
config_dpath = _get_config_directory()
......@@ -102,14 +102,15 @@ def test_config_build_pipeline():
def test_config_data_pipeline():
"""
Test whether the data pipeline is valid and can process corner cases.
"""Test whether the data pipeline is valid and can process corner cases.
CommandLine:
xdoctest -m tests/test_config.py test_config_build_data_pipeline
"""
import numpy as np
from mmcv import Config
from mmdet3d.datasets.pipelines import Compose
import numpy as np
config_dpath = _get_config_directory()
print('Found config_dpath = {!r}'.format(config_dpath))
......@@ -262,7 +263,7 @@ def _check_roi_head(config, head):
def _check_roi_extractor(config, roi_extractor, prev_roi_extractor=None):
import torch.nn as nn
from torch import nn as nn
if isinstance(roi_extractor, nn.ModuleList):
if prev_roi_extractor:
prev_roi_extractor = prev_roi_extractor[0]
......@@ -289,7 +290,7 @@ def _check_roi_extractor(config, roi_extractor, prev_roi_extractor=None):
def _check_mask_head(mask_cfg, mask_head):
import torch.nn as nn
from torch import nn as nn
if isinstance(mask_cfg, list):
for single_mask_cfg, single_mask_head in zip(mask_cfg, mask_head):
_check_mask_head(single_mask_cfg, single_mask_head)
......@@ -307,7 +308,7 @@ def _check_mask_head(mask_cfg, mask_head):
def _check_bbox_head(bbox_cfg, bbox_head):
import torch.nn as nn
from torch import nn as nn
if isinstance(bbox_cfg, list):
for single_bbox_cfg, single_bbox_head in zip(bbox_cfg, bbox_head):
_check_bbox_head(single_bbox_cfg, single_bbox_head)
......@@ -357,7 +358,7 @@ def _check_parta2_roi_extractor(config, roi_extractor):
def _check_parta2_bbox_head(bbox_cfg, bbox_head):
import torch.nn as nn
from torch import nn as nn
if isinstance(bbox_cfg, list):
for single_bbox_cfg, single_bbox_head in zip(bbox_cfg, bbox_head):
_check_bbox_head(single_bbox_cfg, single_bbox_head)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment