Commit 80b39bd0 authored by zhangwenwei's avatar zhangwenwei
Browse files

Reformat docstrings in code

parent 64d7fbc2
import torch import torch
import torch.nn as nn from torch import nn as nn
from mmdet3d import ops from mmdet3d import ops
from mmdet.models.builder import ROI_EXTRACTORS from mmdet.models.builder import ROI_EXTRACTORS
...@@ -7,7 +7,7 @@ from mmdet.models.builder import ROI_EXTRACTORS ...@@ -7,7 +7,7 @@ from mmdet.models.builder import ROI_EXTRACTORS
@ROI_EXTRACTORS.register_module() @ROI_EXTRACTORS.register_module()
class Single3DRoIAwareExtractor(nn.Module): class Single3DRoIAwareExtractor(nn.Module):
"""Point-wise roi-aware Extractor """Point-wise roi-aware Extractor.
Extract Point-wise roi features. Extract Point-wise roi features.
...@@ -29,7 +29,7 @@ class Single3DRoIAwareExtractor(nn.Module): ...@@ -29,7 +29,7 @@ class Single3DRoIAwareExtractor(nn.Module):
return roi_layers return roi_layers
def forward(self, feats, coordinate, batch_inds, rois): def forward(self, feats, coordinate, batch_inds, rois):
"""Extract point-wise roi features """Extract point-wise roi features.
Args: Args:
feats (FloatTensor): point-wise features with feats (FloatTensor): point-wise features with
......
...@@ -83,7 +83,7 @@ class PillarFeatureNet(nn.Module): ...@@ -83,7 +83,7 @@ class PillarFeatureNet(nn.Module):
self.point_cloud_range = point_cloud_range self.point_cloud_range = point_cloud_range
def forward(self, features, num_points, coors): def forward(self, features, num_points, coors):
"""Forward function """Forward function.
Args: Args:
features (torch.Tensor): Point features or raw points in shape features (torch.Tensor): Point features or raw points in shape
...@@ -136,7 +136,7 @@ class PillarFeatureNet(nn.Module): ...@@ -136,7 +136,7 @@ class PillarFeatureNet(nn.Module):
@VOXEL_ENCODERS.register_module() @VOXEL_ENCODERS.register_module()
class DynamicPillarFeatureNet(PillarFeatureNet): class DynamicPillarFeatureNet(PillarFeatureNet):
"""Pillar Feature Net using dynamic voxelization """Pillar Feature Net using dynamic voxelization.
The network prepares the pillar features and performs forward pass The network prepares the pillar features and performs forward pass
through PFNLayers. The main difference is that it is used for through PFNLayers. The main difference is that it is used for
...@@ -205,7 +205,7 @@ class DynamicPillarFeatureNet(PillarFeatureNet): ...@@ -205,7 +205,7 @@ class DynamicPillarFeatureNet(PillarFeatureNet):
voxel_size, point_cloud_range, average_points=True) voxel_size, point_cloud_range, average_points=True)
def map_voxel_center_to_point(self, pts_coors, voxel_mean, voxel_coors): def map_voxel_center_to_point(self, pts_coors, voxel_mean, voxel_coors):
"""Map the centers of voxels to its corresponding points """Map the centers of voxels to its corresponding points.
Args: Args:
pts_coors (torch.Tensor): The coordinates of each points, shape pts_coors (torch.Tensor): The coordinates of each points, shape
...@@ -244,7 +244,7 @@ class DynamicPillarFeatureNet(PillarFeatureNet): ...@@ -244,7 +244,7 @@ class DynamicPillarFeatureNet(PillarFeatureNet):
return center_per_point return center_per_point
def forward(self, features, coors): def forward(self, features, coors):
"""Forward function """Forward function.
Args: Args:
features (torch.Tensor): Point features or raw points in shape features (torch.Tensor): Point features or raw points in shape
......
...@@ -28,7 +28,7 @@ def get_paddings_indicator(actual_num, max_num, axis=0): ...@@ -28,7 +28,7 @@ def get_paddings_indicator(actual_num, max_num, axis=0):
class VFELayer(nn.Module): class VFELayer(nn.Module):
""" Voxel Feature Encoder layer. """Voxel Feature Encoder layer.
The voxel encoder is composed of a series of these layers. The voxel encoder is composed of a series of these layers.
This module do not support average pooling and only support to use This module do not support average pooling and only support to use
...@@ -59,7 +59,7 @@ class VFELayer(nn.Module): ...@@ -59,7 +59,7 @@ class VFELayer(nn.Module):
self.linear = nn.Linear(in_channels, out_channels, bias=False) self.linear = nn.Linear(in_channels, out_channels, bias=False)
def forward(self, inputs): def forward(self, inputs):
"""Forward function """Forward function.
Args: Args:
inputs (torch.Tensor): Voxels features of shape (N, M, C). inputs (torch.Tensor): Voxels features of shape (N, M, C).
...@@ -100,7 +100,7 @@ class VFELayer(nn.Module): ...@@ -100,7 +100,7 @@ class VFELayer(nn.Module):
class PFNLayer(nn.Module): class PFNLayer(nn.Module):
""" Pillar Feature Net Layer. """Pillar Feature Net Layer.
The Pillar Feature Net is composed of a series of these layers, but the The Pillar Feature Net is composed of a series of these layers, but the
PointPillars paper results only used a single PFNLayer. PointPillars paper results only used a single PFNLayer.
...@@ -136,7 +136,7 @@ class PFNLayer(nn.Module): ...@@ -136,7 +136,7 @@ class PFNLayer(nn.Module):
self.mode = mode self.mode = mode
def forward(self, inputs, num_voxels=None, aligned_distance=None): def forward(self, inputs, num_voxels=None, aligned_distance=None):
"""Forward function """Forward function.
Args: Args:
inputs (torch.Tensor): Pillar/Voxel inputs with shape (N, M, C). inputs (torch.Tensor): Pillar/Voxel inputs with shape (N, M, C).
......
...@@ -10,7 +10,7 @@ from .utils import VFELayer, get_paddings_indicator ...@@ -10,7 +10,7 @@ from .utils import VFELayer, get_paddings_indicator
@VOXEL_ENCODERS.register_module() @VOXEL_ENCODERS.register_module()
class HardSimpleVFE(nn.Module): class HardSimpleVFE(nn.Module):
"""Simple voxel feature encoder used in SECOND """Simple voxel feature encoder used in SECOND.
It simply averages the values of points in a voxel. It simply averages the values of points in a voxel.
""" """
...@@ -19,7 +19,7 @@ class HardSimpleVFE(nn.Module): ...@@ -19,7 +19,7 @@ class HardSimpleVFE(nn.Module):
super(HardSimpleVFE, self).__init__() super(HardSimpleVFE, self).__init__()
def forward(self, features, num_points, coors): def forward(self, features, num_points, coors):
"""Forward function """Forward function.
Args: Args:
features (torch.Tensor): point features in shape features (torch.Tensor): point features in shape
...@@ -39,7 +39,7 @@ class HardSimpleVFE(nn.Module): ...@@ -39,7 +39,7 @@ class HardSimpleVFE(nn.Module):
@VOXEL_ENCODERS.register_module() @VOXEL_ENCODERS.register_module()
class DynamicSimpleVFE(nn.Module): class DynamicSimpleVFE(nn.Module):
"""Simple dynamic voxel feature encoder used in DV-SECOND """Simple dynamic voxel feature encoder used in DV-SECOND.
It simply averages the values of points in a voxel. It simply averages the values of points in a voxel.
But the number of points in a voxel is dynamic and varies. But the number of points in a voxel is dynamic and varies.
...@@ -57,7 +57,7 @@ class DynamicSimpleVFE(nn.Module): ...@@ -57,7 +57,7 @@ class DynamicSimpleVFE(nn.Module):
@torch.no_grad() @torch.no_grad()
def forward(self, features, coors): def forward(self, features, coors):
"""Forward function """Forward function.
Args: Args:
features (torch.Tensor): point features in shape features (torch.Tensor): point features in shape
...@@ -76,7 +76,7 @@ class DynamicSimpleVFE(nn.Module): ...@@ -76,7 +76,7 @@ class DynamicSimpleVFE(nn.Module):
@VOXEL_ENCODERS.register_module() @VOXEL_ENCODERS.register_module()
class DynamicVFE(nn.Module): class DynamicVFE(nn.Module):
"""Dynamic Voxel feature encoder used in DV-SECOND """Dynamic Voxel feature encoder used in DV-SECOND.
It encodes features of voxels and their points. It could also fuse It encodes features of voxels and their points. It could also fuse
image feature into voxel features in a point-wise manner. image feature into voxel features in a point-wise manner.
...@@ -211,7 +211,7 @@ class DynamicVFE(nn.Module): ...@@ -211,7 +211,7 @@ class DynamicVFE(nn.Module):
points=None, points=None,
img_feats=None, img_feats=None,
img_metas=None): img_metas=None):
"""Forward functions """Forward functions.
Args: Args:
features (torch.Tensor): Features of voxels, shape is NxC. features (torch.Tensor): Features of voxels, shape is NxC.
...@@ -274,7 +274,7 @@ class DynamicVFE(nn.Module): ...@@ -274,7 +274,7 @@ class DynamicVFE(nn.Module):
@VOXEL_ENCODERS.register_module() @VOXEL_ENCODERS.register_module()
class HardVFE(nn.Module): class HardVFE(nn.Module):
"""Voxel feature encoder used in DV-SECOND """Voxel feature encoder used in DV-SECOND.
It encodes features of voxels and their points. It could also fuse It encodes features of voxels and their points. It could also fuse
image feature into voxel features in a point-wise manner. image feature into voxel features in a point-wise manner.
...@@ -374,7 +374,7 @@ class HardVFE(nn.Module): ...@@ -374,7 +374,7 @@ class HardVFE(nn.Module):
coors, coors,
img_feats=None, img_feats=None,
img_metas=None): img_metas=None):
"""Forward functions """Forward functions.
Args: Args:
features (torch.Tensor): Features of voxels, shape is MxNxC. features (torch.Tensor): Features of voxels, shape is MxNxC.
......
...@@ -5,7 +5,7 @@ from . import ball_query_ext ...@@ -5,7 +5,7 @@ from . import ball_query_ext
class BallQuery(Function): class BallQuery(Function):
"""Ball Query """Ball Query.
Find nearby points in spherical space. Find nearby points in spherical space.
""" """
......
...@@ -7,8 +7,8 @@ from . import furthest_point_sample_ext ...@@ -7,8 +7,8 @@ from . import furthest_point_sample_ext
class FurthestPointSampling(Function): class FurthestPointSampling(Function):
"""Furthest Point Sampling. """Furthest Point Sampling.
Uses iterative furthest point sampling to select a set of Uses iterative furthest point sampling to select a set of features whose
features whose corresponding points have the furthest distance. corresponding points have the furthest distance.
""" """
@staticmethod @staticmethod
......
...@@ -5,7 +5,7 @@ from . import gather_points_ext ...@@ -5,7 +5,7 @@ from . import gather_points_ext
class GatherPoints(Function): class GatherPoints(Function):
"""Gather Points """Gather Points.
Gather points with given index. Gather points with given index.
""" """
......
from typing import Tuple
import torch import torch
import torch.nn as nn from torch import nn as nn
from torch.autograd import Function from torch.autograd import Function
from typing import Tuple
from ..ball_query import ball_query from ..ball_query import ball_query
from . import group_points_ext from . import group_points_ext
...@@ -49,7 +48,7 @@ class QueryAndGroup(nn.Module): ...@@ -49,7 +48,7 @@ class QueryAndGroup(nn.Module):
assert self.uniform_sample assert self.uniform_sample
def forward(self, points_xyz, center_xyz, features=None): def forward(self, points_xyz, center_xyz, features=None):
"""forward """forward.
Args: Args:
points_xyz (Tensor): (B, N, 3) xyz coordinates of the features. points_xyz (Tensor): (B, N, 3) xyz coordinates of the features.
......
from typing import Tuple
import torch import torch
from torch.autograd import Function from torch.autograd import Function
from typing import Tuple
from . import interpolate_ext from . import interpolate_ext
...@@ -11,7 +10,7 @@ class ThreeInterpolate(Function): ...@@ -11,7 +10,7 @@ class ThreeInterpolate(Function):
@staticmethod @staticmethod
def forward(ctx, features: torch.Tensor, indices: torch.Tensor, def forward(ctx, features: torch.Tensor, indices: torch.Tensor,
weight: torch.Tensor) -> torch.Tensor: weight: torch.Tensor) -> torch.Tensor:
"""Performs weighted linear interpolation on 3 features """Performs weighted linear interpolation on 3 features.
Args: Args:
features (Tensor): (B, C, M) Features descriptors to be features (Tensor): (B, C, M) Features descriptors to be
...@@ -40,7 +39,7 @@ class ThreeInterpolate(Function): ...@@ -40,7 +39,7 @@ class ThreeInterpolate(Function):
def backward( def backward(
ctx, grad_out: torch.Tensor ctx, grad_out: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Backward of three interpolate """Backward of three interpolate.
Args: Args:
grad_out (Tensor): (B, C, N) tensor with gradients of outputs grad_out (Tensor): (B, C, N) tensor with gradients of outputs
......
from typing import Tuple
import torch import torch
from torch.autograd import Function from torch.autograd import Function
from typing import Tuple
from . import interpolate_ext from . import interpolate_ext
...@@ -11,7 +10,8 @@ class ThreeNN(Function): ...@@ -11,7 +10,8 @@ class ThreeNN(Function):
@staticmethod @staticmethod
def forward(ctx, target: torch.Tensor, def forward(ctx, target: torch.Tensor,
source: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: source: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Find the top-3 nearest neighbors of the target set from the source set. """Find the top-3 nearest neighbors of the target set from the source
set.
Args: Args:
target (Tensor): shape (B, N, 3), points set that needs to target (Tensor): shape (B, N, 3), points set that needs to
......
import torch import torch
import torch.distributed as dist
import torch.nn as nn
from mmcv.cnn import NORM_LAYERS from mmcv.cnn import NORM_LAYERS
from torch import distributed as dist
from torch import nn as nn
from torch.autograd.function import Function from torch.autograd.function import Function
...@@ -25,7 +25,7 @@ class AllReduce(Function): ...@@ -25,7 +25,7 @@ class AllReduce(Function):
@NORM_LAYERS.register_module('naiveSyncBN1d') @NORM_LAYERS.register_module('naiveSyncBN1d')
class NaiveSyncBatchNorm1d(nn.BatchNorm1d): class NaiveSyncBatchNorm1d(nn.BatchNorm1d):
"""Syncronized Batch Normalization for 3D Tensors """Syncronized Batch Normalization for 3D Tensors.
Note: Note:
This implementation is modified from This implementation is modified from
...@@ -70,7 +70,7 @@ class NaiveSyncBatchNorm1d(nn.BatchNorm1d): ...@@ -70,7 +70,7 @@ class NaiveSyncBatchNorm1d(nn.BatchNorm1d):
@NORM_LAYERS.register_module('naiveSyncBN2d') @NORM_LAYERS.register_module('naiveSyncBN2d')
class NaiveSyncBatchNorm2d(nn.BatchNorm2d): class NaiveSyncBatchNorm2d(nn.BatchNorm2d):
"""Syncronized Batch Normalization for 4D Tensors """Syncronized Batch Normalization for 4D Tensors.
Note: Note:
This implementation is modified from This implementation is modified from
......
from typing import List
import torch import torch
import torch.nn as nn
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from torch import nn as nn
from typing import List
from mmdet3d.ops import three_interpolate, three_nn from mmdet3d.ops import three_interpolate, three_nn
......
from typing import List
import torch import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from torch import nn as nn
from torch.nn import functional as F
from typing import List
from mmdet3d.ops import (GroupAll, QueryAndGroup, furthest_point_sample, from mmdet3d.ops import (GroupAll, QueryAndGroup, furthest_point_sample,
gather_points) gather_points)
class PointSAModuleMSG(nn.Module): class PointSAModuleMSG(nn.Module):
"""Point set abstraction module with multi-scale grouping used in Pointnets. """Point set abstraction module with multi-scale grouping used in
Pointnets.
Args: Args:
num_point (int): Number of points. num_point (int): Number of points.
......
import mmcv import mmcv
import torch import torch
import torch.nn as nn from torch import nn as nn
from torch.autograd import Function from torch.autograd import Function
from . import roiaware_pool3d_ext from . import roiaware_pool3d_ext
...@@ -24,7 +24,7 @@ class RoIAwarePool3d(nn.Module): ...@@ -24,7 +24,7 @@ class RoIAwarePool3d(nn.Module):
self.mode = pool_method_map[mode] self.mode = pool_method_map[mode]
def forward(self, rois, pts, pts_feature): def forward(self, rois, pts, pts_feature):
"""RoIAwarePool3d module forward """RoIAwarePool3d module forward.
Args: Args:
rois (torch.Tensor): [N, 7],in LiDAR coordinate, rois (torch.Tensor): [N, 7],in LiDAR coordinate,
...@@ -46,7 +46,7 @@ class RoIAwarePool3dFunction(Function): ...@@ -46,7 +46,7 @@ class RoIAwarePool3dFunction(Function):
@staticmethod @staticmethod
def forward(ctx, rois, pts, pts_feature, out_size, max_pts_per_voxel, def forward(ctx, rois, pts, pts_feature, out_size, max_pts_per_voxel,
mode): mode):
"""RoIAwarePool3d function forward """RoIAwarePool3d function forward.
Args: Args:
rois (torch.Tensor): [N, 7], in LiDAR coordinate, rois (torch.Tensor): [N, 7], in LiDAR coordinate,
...@@ -89,7 +89,7 @@ class RoIAwarePool3dFunction(Function): ...@@ -89,7 +89,7 @@ class RoIAwarePool3dFunction(Function):
@staticmethod @staticmethod
def backward(ctx, grad_out): def backward(ctx, grad_out):
"""RoIAwarePool3d function forward """RoIAwarePool3d function forward.
Args: Args:
grad_out (torch.Tensor): [N, out_x, out_y, out_z, C] grad_out (torch.Tensor): [N, out_x, out_y, out_z, C]
......
...@@ -11,9 +11,7 @@ ...@@ -11,9 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import math import math
import numpy as np import numpy as np
import torch import torch
from mmcv.cnn import CONV_LAYERS from mmcv.cnn import CONV_LAYERS
......
...@@ -11,11 +11,9 @@ ...@@ -11,11 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import sys import sys
from collections import OrderedDict
import torch import torch
from collections import OrderedDict
from torch import nn from torch import nn
from .structure import SparseConvTensor from .structure import SparseConvTensor
...@@ -46,9 +44,8 @@ def _mean_update(vals, m_vals, t): ...@@ -46,9 +44,8 @@ def _mean_update(vals, m_vals, t):
class SparseModule(nn.Module): class SparseModule(nn.Module):
""" place holder, """place holder, All module subclass from this will take sptensor in
All module subclass from this will take sptensor in SparseSequential. SparseSequential."""
"""
pass pass
...@@ -140,7 +137,9 @@ class SparseSequential(SparseModule): ...@@ -140,7 +137,9 @@ class SparseSequential(SparseModule):
return input return input
def fused(self): def fused(self):
"""don't use this. no effect. """don't use this.
no effect.
""" """
from .conv import SparseConvolution from .conv import SparseConvolution
mods = [v for k, v in self._modules.items()] mods = [v for k, v in self._modules.items()]
...@@ -189,16 +188,14 @@ class SparseSequential(SparseModule): ...@@ -189,16 +188,14 @@ class SparseSequential(SparseModule):
class ToDense(SparseModule): class ToDense(SparseModule):
"""convert SparseConvTensor to NCHW dense tensor. """convert SparseConvTensor to NCHW dense tensor."""
"""
def forward(self, x: SparseConvTensor): def forward(self, x: SparseConvTensor):
return x.dense() return x.dense()
class RemoveGrid(SparseModule): class RemoveGrid(SparseModule):
"""remove pre-allocated grid buffer. """remove pre-allocated grid buffer."""
"""
def forward(self, x: SparseConvTensor): def forward(self, x: SparseConvTensor):
x.grid = None x.grid = None
......
...@@ -4,9 +4,9 @@ import torch ...@@ -4,9 +4,9 @@ import torch
def scatter_nd(indices, updates, shape): def scatter_nd(indices, updates, shape):
"""pytorch edition of tensorflow scatter_nd. """pytorch edition of tensorflow scatter_nd.
this function don't contain except handle code. so use this carefully
when indice repeats, don't support repeat add which is supported this function don't contain except handle code. so use this carefully when
in tensorflow. indice repeats, don't support repeat add which is supported in tensorflow.
""" """
ret = torch.zeros(*shape, dtype=updates.dtype, device=updates.device) ret = torch.zeros(*shape, dtype=updates.dtype, device=updates.device)
ndim = indices.shape[-1] ndim = indices.shape[-1]
......
...@@ -11,10 +11,8 @@ ...@@ -11,10 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import unittest
import numpy as np import numpy as np
import unittest
class TestCase(unittest.TestCase): class TestCase(unittest.TestCase):
...@@ -26,6 +24,7 @@ class TestCase(unittest.TestCase): ...@@ -26,6 +24,7 @@ class TestCase(unittest.TestCase):
def assertAllEqual(self, a, b): def assertAllEqual(self, a, b):
"""Asserts that two numpy arrays have the same values. """Asserts that two numpy arrays have the same values.
Args: Args:
a: the expected numpy ndarray or anything can be converted to one. a: the expected numpy ndarray or anything can be converted to one.
b: the actual numpy ndarray or anything can be converted to one. b: the actual numpy ndarray or anything can be converted to one.
...@@ -56,6 +55,7 @@ class TestCase(unittest.TestCase): ...@@ -56,6 +55,7 @@ class TestCase(unittest.TestCase):
def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6): def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6):
"""Asserts that two numpy arrays, or dicts of same, have near values. """Asserts that two numpy arrays, or dicts of same, have near values.
This does not support nested dicts. This does not support nested dicts.
Args: Args:
a: The expected numpy ndarray (or anything can be converted to one), or a: The expected numpy ndarray (or anything can be converted to one), or
......
import os.path as osp
import subprocess
import sys
from collections import defaultdict
import cv2 import cv2
import mmcv import mmcv
import subprocess
import sys
import torch import torch
import torchvision import torchvision
from collections import defaultdict
from os import path as osp
import mmdet import mmdet
import mmdet3d import mmdet3d
......
...@@ -4,7 +4,7 @@ from mmdet.core import BitmapMasks, PolygonMasks ...@@ -4,7 +4,7 @@ from mmdet.core import BitmapMasks, PolygonMasks
def _get_config_directory(): def _get_config_directory():
""" Find the predefined detector config directory """ """Find the predefined detector config directory."""
try: try:
# Assume we are running in the source mmdetection repo # Assume we are running in the source mmdetection repo
repo_dpath = dirname(dirname(__file__)) repo_dpath = dirname(dirname(__file__))
...@@ -19,10 +19,10 @@ def _get_config_directory(): ...@@ -19,10 +19,10 @@ def _get_config_directory():
def test_config_build_detector(): def test_config_build_detector():
""" """Test that all detection models defined in the configs can be
Test that all detection models defined in the configs can be initialized. initialized."""
"""
from mmcv import Config from mmcv import Config
from mmdet3d.models import build_detector from mmdet3d.models import build_detector
config_dpath = _get_config_directory() config_dpath = _get_config_directory()
...@@ -74,10 +74,10 @@ def test_config_build_detector(): ...@@ -74,10 +74,10 @@ def test_config_build_detector():
def test_config_build_pipeline(): def test_config_build_pipeline():
""" """Test that all detection models defined in the configs can be
Test that all detection models defined in the configs can be initialized. initialized."""
"""
from mmcv import Config from mmcv import Config
from mmdet3d.datasets.pipelines import Compose from mmdet3d.datasets.pipelines import Compose
config_dpath = _get_config_directory() config_dpath = _get_config_directory()
...@@ -102,14 +102,15 @@ def test_config_build_pipeline(): ...@@ -102,14 +102,15 @@ def test_config_build_pipeline():
def test_config_data_pipeline(): def test_config_data_pipeline():
""" """Test whether the data pipeline is valid and can process corner cases.
Test whether the data pipeline is valid and can process corner cases.
CommandLine: CommandLine:
xdoctest -m tests/test_config.py test_config_build_data_pipeline xdoctest -m tests/test_config.py test_config_build_data_pipeline
""" """
import numpy as np
from mmcv import Config from mmcv import Config
from mmdet3d.datasets.pipelines import Compose from mmdet3d.datasets.pipelines import Compose
import numpy as np
config_dpath = _get_config_directory() config_dpath = _get_config_directory()
print('Found config_dpath = {!r}'.format(config_dpath)) print('Found config_dpath = {!r}'.format(config_dpath))
...@@ -262,7 +263,7 @@ def _check_roi_head(config, head): ...@@ -262,7 +263,7 @@ def _check_roi_head(config, head):
def _check_roi_extractor(config, roi_extractor, prev_roi_extractor=None): def _check_roi_extractor(config, roi_extractor, prev_roi_extractor=None):
import torch.nn as nn from torch import nn as nn
if isinstance(roi_extractor, nn.ModuleList): if isinstance(roi_extractor, nn.ModuleList):
if prev_roi_extractor: if prev_roi_extractor:
prev_roi_extractor = prev_roi_extractor[0] prev_roi_extractor = prev_roi_extractor[0]
...@@ -289,7 +290,7 @@ def _check_roi_extractor(config, roi_extractor, prev_roi_extractor=None): ...@@ -289,7 +290,7 @@ def _check_roi_extractor(config, roi_extractor, prev_roi_extractor=None):
def _check_mask_head(mask_cfg, mask_head): def _check_mask_head(mask_cfg, mask_head):
import torch.nn as nn from torch import nn as nn
if isinstance(mask_cfg, list): if isinstance(mask_cfg, list):
for single_mask_cfg, single_mask_head in zip(mask_cfg, mask_head): for single_mask_cfg, single_mask_head in zip(mask_cfg, mask_head):
_check_mask_head(single_mask_cfg, single_mask_head) _check_mask_head(single_mask_cfg, single_mask_head)
...@@ -307,7 +308,7 @@ def _check_mask_head(mask_cfg, mask_head): ...@@ -307,7 +308,7 @@ def _check_mask_head(mask_cfg, mask_head):
def _check_bbox_head(bbox_cfg, bbox_head): def _check_bbox_head(bbox_cfg, bbox_head):
import torch.nn as nn from torch import nn as nn
if isinstance(bbox_cfg, list): if isinstance(bbox_cfg, list):
for single_bbox_cfg, single_bbox_head in zip(bbox_cfg, bbox_head): for single_bbox_cfg, single_bbox_head in zip(bbox_cfg, bbox_head):
_check_bbox_head(single_bbox_cfg, single_bbox_head) _check_bbox_head(single_bbox_cfg, single_bbox_head)
...@@ -357,7 +358,7 @@ def _check_parta2_roi_extractor(config, roi_extractor): ...@@ -357,7 +358,7 @@ def _check_parta2_roi_extractor(config, roi_extractor):
def _check_parta2_bbox_head(bbox_cfg, bbox_head): def _check_parta2_bbox_head(bbox_cfg, bbox_head):
import torch.nn as nn from torch import nn as nn
if isinstance(bbox_cfg, list): if isinstance(bbox_cfg, list):
for single_bbox_cfg, single_bbox_head in zip(bbox_cfg, bbox_head): for single_bbox_cfg, single_bbox_head in zip(bbox_cfg, bbox_head):
_check_bbox_head(single_bbox_cfg, single_bbox_head) _check_bbox_head(single_bbox_cfg, single_bbox_head)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment