Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
80b39bd0
Commit
80b39bd0
authored
Jul 04, 2020
by
zhangwenwei
Browse files
Reformat docstrings in code
parent
64d7fbc2
Changes
101
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
80 additions
and
88 deletions
+80
-88
mmdet3d/models/roi_heads/roi_extractors/single_roiaware_extractor.py
...els/roi_heads/roi_extractors/single_roiaware_extractor.py
+3
-3
mmdet3d/models/voxel_encoders/pillar_encoder.py
mmdet3d/models/voxel_encoders/pillar_encoder.py
+4
-4
mmdet3d/models/voxel_encoders/utils.py
mmdet3d/models/voxel_encoders/utils.py
+4
-4
mmdet3d/models/voxel_encoders/voxel_encoder.py
mmdet3d/models/voxel_encoders/voxel_encoder.py
+8
-8
mmdet3d/ops/ball_query/ball_query.py
mmdet3d/ops/ball_query/ball_query.py
+1
-1
mmdet3d/ops/furthest_point_sample/furthest_point_sample.py
mmdet3d/ops/furthest_point_sample/furthest_point_sample.py
+2
-2
mmdet3d/ops/gather_points/gather_points.py
mmdet3d/ops/gather_points/gather_points.py
+1
-1
mmdet3d/ops/group_points/group_points.py
mmdet3d/ops/group_points/group_points.py
+3
-4
mmdet3d/ops/interpolate/three_interpolate.py
mmdet3d/ops/interpolate/three_interpolate.py
+3
-4
mmdet3d/ops/interpolate/three_nn.py
mmdet3d/ops/interpolate/three_nn.py
+3
-3
mmdet3d/ops/norm.py
mmdet3d/ops/norm.py
+4
-4
mmdet3d/ops/pointnet_modules/point_fp_module.py
mmdet3d/ops/pointnet_modules/point_fp_module.py
+2
-3
mmdet3d/ops/pointnet_modules/point_sa_module.py
mmdet3d/ops/pointnet_modules/point_sa_module.py
+5
-5
mmdet3d/ops/roiaware_pool3d/roiaware_pool3d.py
mmdet3d/ops/roiaware_pool3d/roiaware_pool3d.py
+4
-4
mmdet3d/ops/spconv/conv.py
mmdet3d/ops/spconv/conv.py
+0
-2
mmdet3d/ops/spconv/modules.py
mmdet3d/ops/spconv/modules.py
+8
-11
mmdet3d/ops/spconv/structure.py
mmdet3d/ops/spconv/structure.py
+3
-3
mmdet3d/ops/spconv/test_utils.py
mmdet3d/ops/spconv/test_utils.py
+3
-3
mmdet3d/utils/collect_env.py
mmdet3d/utils/collect_env.py
+4
-5
tests/test_config.py
tests/test_config.py
+15
-14
No files found.
mmdet3d/models/roi_heads/roi_extractors/single_roiaware_extractor.py
View file @
80b39bd0
import
torch
import
torch
import
torch.
nn
as
nn
from
torch
import
nn
as
nn
from
mmdet3d
import
ops
from
mmdet3d
import
ops
from
mmdet.models.builder
import
ROI_EXTRACTORS
from
mmdet.models.builder
import
ROI_EXTRACTORS
...
@@ -7,7 +7,7 @@ from mmdet.models.builder import ROI_EXTRACTORS
...
@@ -7,7 +7,7 @@ from mmdet.models.builder import ROI_EXTRACTORS
@
ROI_EXTRACTORS
.
register_module
()
@
ROI_EXTRACTORS
.
register_module
()
class
Single3DRoIAwareExtractor
(
nn
.
Module
):
class
Single3DRoIAwareExtractor
(
nn
.
Module
):
"""Point-wise roi-aware Extractor
"""Point-wise roi-aware Extractor
.
Extract Point-wise roi features.
Extract Point-wise roi features.
...
@@ -29,7 +29,7 @@ class Single3DRoIAwareExtractor(nn.Module):
...
@@ -29,7 +29,7 @@ class Single3DRoIAwareExtractor(nn.Module):
return
roi_layers
return
roi_layers
def
forward
(
self
,
feats
,
coordinate
,
batch_inds
,
rois
):
def
forward
(
self
,
feats
,
coordinate
,
batch_inds
,
rois
):
"""Extract point-wise roi features
"""Extract point-wise roi features
.
Args:
Args:
feats (FloatTensor): point-wise features with
feats (FloatTensor): point-wise features with
...
...
mmdet3d/models/voxel_encoders/pillar_encoder.py
View file @
80b39bd0
...
@@ -83,7 +83,7 @@ class PillarFeatureNet(nn.Module):
...
@@ -83,7 +83,7 @@ class PillarFeatureNet(nn.Module):
self
.
point_cloud_range
=
point_cloud_range
self
.
point_cloud_range
=
point_cloud_range
def
forward
(
self
,
features
,
num_points
,
coors
):
def
forward
(
self
,
features
,
num_points
,
coors
):
"""Forward function
"""Forward function
.
Args:
Args:
features (torch.Tensor): Point features or raw points in shape
features (torch.Tensor): Point features or raw points in shape
...
@@ -136,7 +136,7 @@ class PillarFeatureNet(nn.Module):
...
@@ -136,7 +136,7 @@ class PillarFeatureNet(nn.Module):
@
VOXEL_ENCODERS
.
register_module
()
@
VOXEL_ENCODERS
.
register_module
()
class
DynamicPillarFeatureNet
(
PillarFeatureNet
):
class
DynamicPillarFeatureNet
(
PillarFeatureNet
):
"""Pillar Feature Net using dynamic voxelization
"""Pillar Feature Net using dynamic voxelization
.
The network prepares the pillar features and performs forward pass
The network prepares the pillar features and performs forward pass
through PFNLayers. The main difference is that it is used for
through PFNLayers. The main difference is that it is used for
...
@@ -205,7 +205,7 @@ class DynamicPillarFeatureNet(PillarFeatureNet):
...
@@ -205,7 +205,7 @@ class DynamicPillarFeatureNet(PillarFeatureNet):
voxel_size
,
point_cloud_range
,
average_points
=
True
)
voxel_size
,
point_cloud_range
,
average_points
=
True
)
def
map_voxel_center_to_point
(
self
,
pts_coors
,
voxel_mean
,
voxel_coors
):
def
map_voxel_center_to_point
(
self
,
pts_coors
,
voxel_mean
,
voxel_coors
):
"""Map the centers of voxels to its corresponding points
"""Map the centers of voxels to its corresponding points
.
Args:
Args:
pts_coors (torch.Tensor): The coordinates of each points, shape
pts_coors (torch.Tensor): The coordinates of each points, shape
...
@@ -244,7 +244,7 @@ class DynamicPillarFeatureNet(PillarFeatureNet):
...
@@ -244,7 +244,7 @@ class DynamicPillarFeatureNet(PillarFeatureNet):
return
center_per_point
return
center_per_point
def
forward
(
self
,
features
,
coors
):
def
forward
(
self
,
features
,
coors
):
"""Forward function
"""Forward function
.
Args:
Args:
features (torch.Tensor): Point features or raw points in shape
features (torch.Tensor): Point features or raw points in shape
...
...
mmdet3d/models/voxel_encoders/utils.py
View file @
80b39bd0
...
@@ -28,7 +28,7 @@ def get_paddings_indicator(actual_num, max_num, axis=0):
...
@@ -28,7 +28,7 @@ def get_paddings_indicator(actual_num, max_num, axis=0):
class
VFELayer
(
nn
.
Module
):
class
VFELayer
(
nn
.
Module
):
"""
Voxel Feature Encoder layer.
"""Voxel Feature Encoder layer.
The voxel encoder is composed of a series of these layers.
The voxel encoder is composed of a series of these layers.
This module do not support average pooling and only support to use
This module do not support average pooling and only support to use
...
@@ -59,7 +59,7 @@ class VFELayer(nn.Module):
...
@@ -59,7 +59,7 @@ class VFELayer(nn.Module):
self
.
linear
=
nn
.
Linear
(
in_channels
,
out_channels
,
bias
=
False
)
self
.
linear
=
nn
.
Linear
(
in_channels
,
out_channels
,
bias
=
False
)
def
forward
(
self
,
inputs
):
def
forward
(
self
,
inputs
):
"""Forward function
"""Forward function
.
Args:
Args:
inputs (torch.Tensor): Voxels features of shape (N, M, C).
inputs (torch.Tensor): Voxels features of shape (N, M, C).
...
@@ -100,7 +100,7 @@ class VFELayer(nn.Module):
...
@@ -100,7 +100,7 @@ class VFELayer(nn.Module):
class
PFNLayer
(
nn
.
Module
):
class
PFNLayer
(
nn
.
Module
):
"""
Pillar Feature Net Layer.
"""Pillar Feature Net Layer.
The Pillar Feature Net is composed of a series of these layers, but the
The Pillar Feature Net is composed of a series of these layers, but the
PointPillars paper results only used a single PFNLayer.
PointPillars paper results only used a single PFNLayer.
...
@@ -136,7 +136,7 @@ class PFNLayer(nn.Module):
...
@@ -136,7 +136,7 @@ class PFNLayer(nn.Module):
self
.
mode
=
mode
self
.
mode
=
mode
def
forward
(
self
,
inputs
,
num_voxels
=
None
,
aligned_distance
=
None
):
def
forward
(
self
,
inputs
,
num_voxels
=
None
,
aligned_distance
=
None
):
"""Forward function
"""Forward function
.
Args:
Args:
inputs (torch.Tensor): Pillar/Voxel inputs with shape (N, M, C).
inputs (torch.Tensor): Pillar/Voxel inputs with shape (N, M, C).
...
...
mmdet3d/models/voxel_encoders/voxel_encoder.py
View file @
80b39bd0
...
@@ -10,7 +10,7 @@ from .utils import VFELayer, get_paddings_indicator
...
@@ -10,7 +10,7 @@ from .utils import VFELayer, get_paddings_indicator
@
VOXEL_ENCODERS
.
register_module
()
@
VOXEL_ENCODERS
.
register_module
()
class
HardSimpleVFE
(
nn
.
Module
):
class
HardSimpleVFE
(
nn
.
Module
):
"""Simple voxel feature encoder used in SECOND
"""Simple voxel feature encoder used in SECOND
.
It simply averages the values of points in a voxel.
It simply averages the values of points in a voxel.
"""
"""
...
@@ -19,7 +19,7 @@ class HardSimpleVFE(nn.Module):
...
@@ -19,7 +19,7 @@ class HardSimpleVFE(nn.Module):
super
(
HardSimpleVFE
,
self
).
__init__
()
super
(
HardSimpleVFE
,
self
).
__init__
()
def
forward
(
self
,
features
,
num_points
,
coors
):
def
forward
(
self
,
features
,
num_points
,
coors
):
"""Forward function
"""Forward function
.
Args:
Args:
features (torch.Tensor): point features in shape
features (torch.Tensor): point features in shape
...
@@ -39,7 +39,7 @@ class HardSimpleVFE(nn.Module):
...
@@ -39,7 +39,7 @@ class HardSimpleVFE(nn.Module):
@
VOXEL_ENCODERS
.
register_module
()
@
VOXEL_ENCODERS
.
register_module
()
class
DynamicSimpleVFE
(
nn
.
Module
):
class
DynamicSimpleVFE
(
nn
.
Module
):
"""Simple dynamic voxel feature encoder used in DV-SECOND
"""Simple dynamic voxel feature encoder used in DV-SECOND
.
It simply averages the values of points in a voxel.
It simply averages the values of points in a voxel.
But the number of points in a voxel is dynamic and varies.
But the number of points in a voxel is dynamic and varies.
...
@@ -57,7 +57,7 @@ class DynamicSimpleVFE(nn.Module):
...
@@ -57,7 +57,7 @@ class DynamicSimpleVFE(nn.Module):
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
forward
(
self
,
features
,
coors
):
def
forward
(
self
,
features
,
coors
):
"""Forward function
"""Forward function
.
Args:
Args:
features (torch.Tensor): point features in shape
features (torch.Tensor): point features in shape
...
@@ -76,7 +76,7 @@ class DynamicSimpleVFE(nn.Module):
...
@@ -76,7 +76,7 @@ class DynamicSimpleVFE(nn.Module):
@
VOXEL_ENCODERS
.
register_module
()
@
VOXEL_ENCODERS
.
register_module
()
class
DynamicVFE
(
nn
.
Module
):
class
DynamicVFE
(
nn
.
Module
):
"""Dynamic Voxel feature encoder used in DV-SECOND
"""Dynamic Voxel feature encoder used in DV-SECOND
.
It encodes features of voxels and their points. It could also fuse
It encodes features of voxels and their points. It could also fuse
image feature into voxel features in a point-wise manner.
image feature into voxel features in a point-wise manner.
...
@@ -211,7 +211,7 @@ class DynamicVFE(nn.Module):
...
@@ -211,7 +211,7 @@ class DynamicVFE(nn.Module):
points
=
None
,
points
=
None
,
img_feats
=
None
,
img_feats
=
None
,
img_metas
=
None
):
img_metas
=
None
):
"""Forward functions
"""Forward functions
.
Args:
Args:
features (torch.Tensor): Features of voxels, shape is NxC.
features (torch.Tensor): Features of voxels, shape is NxC.
...
@@ -274,7 +274,7 @@ class DynamicVFE(nn.Module):
...
@@ -274,7 +274,7 @@ class DynamicVFE(nn.Module):
@
VOXEL_ENCODERS
.
register_module
()
@
VOXEL_ENCODERS
.
register_module
()
class
HardVFE
(
nn
.
Module
):
class
HardVFE
(
nn
.
Module
):
"""Voxel feature encoder used in DV-SECOND
"""Voxel feature encoder used in DV-SECOND
.
It encodes features of voxels and their points. It could also fuse
It encodes features of voxels and their points. It could also fuse
image feature into voxel features in a point-wise manner.
image feature into voxel features in a point-wise manner.
...
@@ -374,7 +374,7 @@ class HardVFE(nn.Module):
...
@@ -374,7 +374,7 @@ class HardVFE(nn.Module):
coors
,
coors
,
img_feats
=
None
,
img_feats
=
None
,
img_metas
=
None
):
img_metas
=
None
):
"""Forward functions
"""Forward functions
.
Args:
Args:
features (torch.Tensor): Features of voxels, shape is MxNxC.
features (torch.Tensor): Features of voxels, shape is MxNxC.
...
...
mmdet3d/ops/ball_query/ball_query.py
View file @
80b39bd0
...
@@ -5,7 +5,7 @@ from . import ball_query_ext
...
@@ -5,7 +5,7 @@ from . import ball_query_ext
class
BallQuery
(
Function
):
class
BallQuery
(
Function
):
"""Ball Query
"""Ball Query
.
Find nearby points in spherical space.
Find nearby points in spherical space.
"""
"""
...
...
mmdet3d/ops/furthest_point_sample/furthest_point_sample.py
View file @
80b39bd0
...
@@ -7,8 +7,8 @@ from . import furthest_point_sample_ext
...
@@ -7,8 +7,8 @@ from . import furthest_point_sample_ext
class
FurthestPointSampling
(
Function
):
class
FurthestPointSampling
(
Function
):
"""Furthest Point Sampling.
"""Furthest Point Sampling.
Uses iterative furthest point sampling to select a set of
Uses iterative furthest point sampling to select a set of
features whose
features whose
corresponding points have the furthest distance.
corresponding points have the furthest distance.
"""
"""
@
staticmethod
@
staticmethod
...
...
mmdet3d/ops/gather_points/gather_points.py
View file @
80b39bd0
...
@@ -5,7 +5,7 @@ from . import gather_points_ext
...
@@ -5,7 +5,7 @@ from . import gather_points_ext
class
GatherPoints
(
Function
):
class
GatherPoints
(
Function
):
"""Gather Points
"""Gather Points
.
Gather points with given index.
Gather points with given index.
"""
"""
...
...
mmdet3d/ops/group_points/group_points.py
View file @
80b39bd0
from
typing
import
Tuple
import
torch
import
torch
import
torch.
nn
as
nn
from
torch
import
nn
as
nn
from
torch.autograd
import
Function
from
torch.autograd
import
Function
from
typing
import
Tuple
from
..ball_query
import
ball_query
from
..ball_query
import
ball_query
from
.
import
group_points_ext
from
.
import
group_points_ext
...
@@ -49,7 +48,7 @@ class QueryAndGroup(nn.Module):
...
@@ -49,7 +48,7 @@ class QueryAndGroup(nn.Module):
assert
self
.
uniform_sample
assert
self
.
uniform_sample
def
forward
(
self
,
points_xyz
,
center_xyz
,
features
=
None
):
def
forward
(
self
,
points_xyz
,
center_xyz
,
features
=
None
):
"""forward
"""forward
.
Args:
Args:
points_xyz (Tensor): (B, N, 3) xyz coordinates of the features.
points_xyz (Tensor): (B, N, 3) xyz coordinates of the features.
...
...
mmdet3d/ops/interpolate/three_interpolate.py
View file @
80b39bd0
from
typing
import
Tuple
import
torch
import
torch
from
torch.autograd
import
Function
from
torch.autograd
import
Function
from
typing
import
Tuple
from
.
import
interpolate_ext
from
.
import
interpolate_ext
...
@@ -11,7 +10,7 @@ class ThreeInterpolate(Function):
...
@@ -11,7 +10,7 @@ class ThreeInterpolate(Function):
@
staticmethod
@
staticmethod
def
forward
(
ctx
,
features
:
torch
.
Tensor
,
indices
:
torch
.
Tensor
,
def
forward
(
ctx
,
features
:
torch
.
Tensor
,
indices
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
)
->
torch
.
Tensor
:
weight
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""Performs weighted linear interpolation on 3 features
"""Performs weighted linear interpolation on 3 features
.
Args:
Args:
features (Tensor): (B, C, M) Features descriptors to be
features (Tensor): (B, C, M) Features descriptors to be
...
@@ -40,7 +39,7 @@ class ThreeInterpolate(Function):
...
@@ -40,7 +39,7 @@ class ThreeInterpolate(Function):
def
backward
(
def
backward
(
ctx
,
grad_out
:
torch
.
Tensor
ctx
,
grad_out
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""Backward of three interpolate
"""Backward of three interpolate
.
Args:
Args:
grad_out (Tensor): (B, C, N) tensor with gradients of outputs
grad_out (Tensor): (B, C, N) tensor with gradients of outputs
...
...
mmdet3d/ops/interpolate/three_nn.py
View file @
80b39bd0
from
typing
import
Tuple
import
torch
import
torch
from
torch.autograd
import
Function
from
torch.autograd
import
Function
from
typing
import
Tuple
from
.
import
interpolate_ext
from
.
import
interpolate_ext
...
@@ -11,7 +10,8 @@ class ThreeNN(Function):
...
@@ -11,7 +10,8 @@ class ThreeNN(Function):
@
staticmethod
@
staticmethod
def
forward
(
ctx
,
target
:
torch
.
Tensor
,
def
forward
(
ctx
,
target
:
torch
.
Tensor
,
source
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
source
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Find the top-3 nearest neighbors of the target set from the source set.
"""Find the top-3 nearest neighbors of the target set from the source
set.
Args:
Args:
target (Tensor): shape (B, N, 3), points set that needs to
target (Tensor): shape (B, N, 3), points set that needs to
...
...
mmdet3d/ops/norm.py
View file @
80b39bd0
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.nn
as
nn
from
mmcv.cnn
import
NORM_LAYERS
from
mmcv.cnn
import
NORM_LAYERS
from
torch
import
distributed
as
dist
from
torch
import
nn
as
nn
from
torch.autograd.function
import
Function
from
torch.autograd.function
import
Function
...
@@ -25,7 +25,7 @@ class AllReduce(Function):
...
@@ -25,7 +25,7 @@ class AllReduce(Function):
@
NORM_LAYERS
.
register_module
(
'naiveSyncBN1d'
)
@
NORM_LAYERS
.
register_module
(
'naiveSyncBN1d'
)
class
NaiveSyncBatchNorm1d
(
nn
.
BatchNorm1d
):
class
NaiveSyncBatchNorm1d
(
nn
.
BatchNorm1d
):
"""Syncronized Batch Normalization for 3D Tensors
"""Syncronized Batch Normalization for 3D Tensors
.
Note:
Note:
This implementation is modified from
This implementation is modified from
...
@@ -70,7 +70,7 @@ class NaiveSyncBatchNorm1d(nn.BatchNorm1d):
...
@@ -70,7 +70,7 @@ class NaiveSyncBatchNorm1d(nn.BatchNorm1d):
@
NORM_LAYERS
.
register_module
(
'naiveSyncBN2d'
)
@
NORM_LAYERS
.
register_module
(
'naiveSyncBN2d'
)
class
NaiveSyncBatchNorm2d
(
nn
.
BatchNorm2d
):
class
NaiveSyncBatchNorm2d
(
nn
.
BatchNorm2d
):
"""Syncronized Batch Normalization for 4D Tensors
"""Syncronized Batch Normalization for 4D Tensors
.
Note:
Note:
This implementation is modified from
This implementation is modified from
...
...
mmdet3d/ops/pointnet_modules/point_fp_module.py
View file @
80b39bd0
from
typing
import
List
import
torch
import
torch
import
torch.nn
as
nn
from
mmcv.cnn
import
ConvModule
from
mmcv.cnn
import
ConvModule
from
torch
import
nn
as
nn
from
typing
import
List
from
mmdet3d.ops
import
three_interpolate
,
three_nn
from
mmdet3d.ops
import
three_interpolate
,
three_nn
...
...
mmdet3d/ops/pointnet_modules/point_sa_module.py
View file @
80b39bd0
from
typing
import
List
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
mmcv.cnn
import
ConvModule
from
mmcv.cnn
import
ConvModule
from
torch
import
nn
as
nn
from
torch.nn
import
functional
as
F
from
typing
import
List
from
mmdet3d.ops
import
(
GroupAll
,
QueryAndGroup
,
furthest_point_sample
,
from
mmdet3d.ops
import
(
GroupAll
,
QueryAndGroup
,
furthest_point_sample
,
gather_points
)
gather_points
)
class
PointSAModuleMSG
(
nn
.
Module
):
class
PointSAModuleMSG
(
nn
.
Module
):
"""Point set abstraction module with multi-scale grouping used in Pointnets.
"""Point set abstraction module with multi-scale grouping used in
Pointnets.
Args:
Args:
num_point (int): Number of points.
num_point (int): Number of points.
...
...
mmdet3d/ops/roiaware_pool3d/roiaware_pool3d.py
View file @
80b39bd0
import
mmcv
import
mmcv
import
torch
import
torch
import
torch.
nn
as
nn
from
torch
import
nn
as
nn
from
torch.autograd
import
Function
from
torch.autograd
import
Function
from
.
import
roiaware_pool3d_ext
from
.
import
roiaware_pool3d_ext
...
@@ -24,7 +24,7 @@ class RoIAwarePool3d(nn.Module):
...
@@ -24,7 +24,7 @@ class RoIAwarePool3d(nn.Module):
self
.
mode
=
pool_method_map
[
mode
]
self
.
mode
=
pool_method_map
[
mode
]
def
forward
(
self
,
rois
,
pts
,
pts_feature
):
def
forward
(
self
,
rois
,
pts
,
pts_feature
):
"""RoIAwarePool3d module forward
"""RoIAwarePool3d module forward
.
Args:
Args:
rois (torch.Tensor): [N, 7],in LiDAR coordinate,
rois (torch.Tensor): [N, 7],in LiDAR coordinate,
...
@@ -46,7 +46,7 @@ class RoIAwarePool3dFunction(Function):
...
@@ -46,7 +46,7 @@ class RoIAwarePool3dFunction(Function):
@
staticmethod
@
staticmethod
def
forward
(
ctx
,
rois
,
pts
,
pts_feature
,
out_size
,
max_pts_per_voxel
,
def
forward
(
ctx
,
rois
,
pts
,
pts_feature
,
out_size
,
max_pts_per_voxel
,
mode
):
mode
):
"""RoIAwarePool3d function forward
"""RoIAwarePool3d function forward
.
Args:
Args:
rois (torch.Tensor): [N, 7], in LiDAR coordinate,
rois (torch.Tensor): [N, 7], in LiDAR coordinate,
...
@@ -89,7 +89,7 @@ class RoIAwarePool3dFunction(Function):
...
@@ -89,7 +89,7 @@ class RoIAwarePool3dFunction(Function):
@
staticmethod
@
staticmethod
def
backward
(
ctx
,
grad_out
):
def
backward
(
ctx
,
grad_out
):
"""RoIAwarePool3d function forward
"""RoIAwarePool3d function forward
.
Args:
Args:
grad_out (torch.Tensor): [N, out_x, out_y, out_z, C]
grad_out (torch.Tensor): [N, out_x, out_y, out_z, C]
...
...
mmdet3d/ops/spconv/conv.py
View file @
80b39bd0
...
@@ -11,9 +11,7 @@
...
@@ -11,9 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
math
import
math
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
mmcv.cnn
import
CONV_LAYERS
from
mmcv.cnn
import
CONV_LAYERS
...
...
mmdet3d/ops/spconv/modules.py
View file @
80b39bd0
...
@@ -11,11 +11,9 @@
...
@@ -11,11 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
sys
import
sys
from
collections
import
OrderedDict
import
torch
import
torch
from
collections
import
OrderedDict
from
torch
import
nn
from
torch
import
nn
from
.structure
import
SparseConvTensor
from
.structure
import
SparseConvTensor
...
@@ -46,9 +44,8 @@ def _mean_update(vals, m_vals, t):
...
@@ -46,9 +44,8 @@ def _mean_update(vals, m_vals, t):
class
SparseModule
(
nn
.
Module
):
class
SparseModule
(
nn
.
Module
):
""" place holder,
"""place holder, All module subclass from this will take sptensor in
All module subclass from this will take sptensor in SparseSequential.
SparseSequential."""
"""
pass
pass
...
@@ -140,7 +137,9 @@ class SparseSequential(SparseModule):
...
@@ -140,7 +137,9 @@ class SparseSequential(SparseModule):
return
input
return
input
def
fused
(
self
):
def
fused
(
self
):
"""don't use this. no effect.
"""don't use this.
no effect.
"""
"""
from
.conv
import
SparseConvolution
from
.conv
import
SparseConvolution
mods
=
[
v
for
k
,
v
in
self
.
_modules
.
items
()]
mods
=
[
v
for
k
,
v
in
self
.
_modules
.
items
()]
...
@@ -189,16 +188,14 @@ class SparseSequential(SparseModule):
...
@@ -189,16 +188,14 @@ class SparseSequential(SparseModule):
class
ToDense
(
SparseModule
):
class
ToDense
(
SparseModule
):
"""convert SparseConvTensor to NCHW dense tensor.
"""convert SparseConvTensor to NCHW dense tensor."""
"""
def
forward
(
self
,
x
:
SparseConvTensor
):
def
forward
(
self
,
x
:
SparseConvTensor
):
return
x
.
dense
()
return
x
.
dense
()
class
RemoveGrid
(
SparseModule
):
class
RemoveGrid
(
SparseModule
):
"""remove pre-allocated grid buffer.
"""remove pre-allocated grid buffer."""
"""
def
forward
(
self
,
x
:
SparseConvTensor
):
def
forward
(
self
,
x
:
SparseConvTensor
):
x
.
grid
=
None
x
.
grid
=
None
...
...
mmdet3d/ops/spconv/structure.py
View file @
80b39bd0
...
@@ -4,9 +4,9 @@ import torch
...
@@ -4,9 +4,9 @@ import torch
def
scatter_nd
(
indices
,
updates
,
shape
):
def
scatter_nd
(
indices
,
updates
,
shape
):
"""pytorch edition of tensorflow scatter_nd.
"""pytorch edition of tensorflow scatter_nd.
this function don't contain except handle code. so use this carefully
when indice repeats, don't support repeat add which is supported
this function don't contain except handle code. so use this carefully when
in tensorflow.
indice repeats, don't support repeat add which is supported
in tensorflow.
"""
"""
ret
=
torch
.
zeros
(
*
shape
,
dtype
=
updates
.
dtype
,
device
=
updates
.
device
)
ret
=
torch
.
zeros
(
*
shape
,
dtype
=
updates
.
dtype
,
device
=
updates
.
device
)
ndim
=
indices
.
shape
[
-
1
]
ndim
=
indices
.
shape
[
-
1
]
...
...
mmdet3d/ops/spconv/test_utils.py
View file @
80b39bd0
...
@@ -11,10 +11,8 @@
...
@@ -11,10 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
unittest
import
numpy
as
np
import
numpy
as
np
import
unittest
class
TestCase
(
unittest
.
TestCase
):
class
TestCase
(
unittest
.
TestCase
):
...
@@ -26,6 +24,7 @@ class TestCase(unittest.TestCase):
...
@@ -26,6 +24,7 @@ class TestCase(unittest.TestCase):
def
assertAllEqual
(
self
,
a
,
b
):
def
assertAllEqual
(
self
,
a
,
b
):
"""Asserts that two numpy arrays have the same values.
"""Asserts that two numpy arrays have the same values.
Args:
Args:
a: the expected numpy ndarray or anything can be converted to one.
a: the expected numpy ndarray or anything can be converted to one.
b: the actual numpy ndarray or anything can be converted to one.
b: the actual numpy ndarray or anything can be converted to one.
...
@@ -56,6 +55,7 @@ class TestCase(unittest.TestCase):
...
@@ -56,6 +55,7 @@ class TestCase(unittest.TestCase):
def
assertAllClose
(
self
,
a
,
b
,
rtol
=
1e-6
,
atol
=
1e-6
):
def
assertAllClose
(
self
,
a
,
b
,
rtol
=
1e-6
,
atol
=
1e-6
):
"""Asserts that two numpy arrays, or dicts of same, have near values.
"""Asserts that two numpy arrays, or dicts of same, have near values.
This does not support nested dicts.
This does not support nested dicts.
Args:
Args:
a: The expected numpy ndarray (or anything can be converted to one), or
a: The expected numpy ndarray (or anything can be converted to one), or
...
...
mmdet3d/utils/collect_env.py
View file @
80b39bd0
import
os.path
as
osp
import
subprocess
import
sys
from
collections
import
defaultdict
import
cv2
import
cv2
import
mmcv
import
mmcv
import
subprocess
import
sys
import
torch
import
torch
import
torchvision
import
torchvision
from
collections
import
defaultdict
from
os
import
path
as
osp
import
mmdet
import
mmdet
import
mmdet3d
import
mmdet3d
...
...
tests/test_config.py
View file @
80b39bd0
...
@@ -4,7 +4,7 @@ from mmdet.core import BitmapMasks, PolygonMasks
...
@@ -4,7 +4,7 @@ from mmdet.core import BitmapMasks, PolygonMasks
def
_get_config_directory
():
def
_get_config_directory
():
"""
Find the predefined detector config directory
"""
"""Find the predefined detector config directory
.
"""
try
:
try
:
# Assume we are running in the source mmdetection repo
# Assume we are running in the source mmdetection repo
repo_dpath
=
dirname
(
dirname
(
__file__
))
repo_dpath
=
dirname
(
dirname
(
__file__
))
...
@@ -19,10 +19,10 @@ def _get_config_directory():
...
@@ -19,10 +19,10 @@ def _get_config_directory():
def
test_config_build_detector
():
def
test_config_build_detector
():
"""
"""Test that all detection models defined in the configs can be
Test that all detection models defined in the configs can be initialized.
initialized."""
"""
from
mmcv
import
Config
from
mmcv
import
Config
from
mmdet3d.models
import
build_detector
from
mmdet3d.models
import
build_detector
config_dpath
=
_get_config_directory
()
config_dpath
=
_get_config_directory
()
...
@@ -74,10 +74,10 @@ def test_config_build_detector():
...
@@ -74,10 +74,10 @@ def test_config_build_detector():
def
test_config_build_pipeline
():
def
test_config_build_pipeline
():
"""
"""Test that all detection models defined in the configs can be
Test that all detection models defined in the configs can be initialized.
initialized."""
"""
from
mmcv
import
Config
from
mmcv
import
Config
from
mmdet3d.datasets.pipelines
import
Compose
from
mmdet3d.datasets.pipelines
import
Compose
config_dpath
=
_get_config_directory
()
config_dpath
=
_get_config_directory
()
...
@@ -102,14 +102,15 @@ def test_config_build_pipeline():
...
@@ -102,14 +102,15 @@ def test_config_build_pipeline():
def
test_config_data_pipeline
():
def
test_config_data_pipeline
():
"""
"""
Test whether the data pipeline is valid and can process corner cases.
Test whether the data pipeline is valid and can process corner cases.
CommandLine:
CommandLine:
xdoctest -m tests/test_config.py test_config_build_data_pipeline
xdoctest -m tests/test_config.py test_config_build_data_pipeline
"""
"""
import
numpy
as
np
from
mmcv
import
Config
from
mmcv
import
Config
from
mmdet3d.datasets.pipelines
import
Compose
from
mmdet3d.datasets.pipelines
import
Compose
import
numpy
as
np
config_dpath
=
_get_config_directory
()
config_dpath
=
_get_config_directory
()
print
(
'Found config_dpath = {!r}'
.
format
(
config_dpath
))
print
(
'Found config_dpath = {!r}'
.
format
(
config_dpath
))
...
@@ -262,7 +263,7 @@ def _check_roi_head(config, head):
...
@@ -262,7 +263,7 @@ def _check_roi_head(config, head):
def
_check_roi_extractor
(
config
,
roi_extractor
,
prev_roi_extractor
=
None
):
def
_check_roi_extractor
(
config
,
roi_extractor
,
prev_roi_extractor
=
None
):
import
torch.
nn
as
nn
from
torch
import
nn
as
nn
if
isinstance
(
roi_extractor
,
nn
.
ModuleList
):
if
isinstance
(
roi_extractor
,
nn
.
ModuleList
):
if
prev_roi_extractor
:
if
prev_roi_extractor
:
prev_roi_extractor
=
prev_roi_extractor
[
0
]
prev_roi_extractor
=
prev_roi_extractor
[
0
]
...
@@ -289,7 +290,7 @@ def _check_roi_extractor(config, roi_extractor, prev_roi_extractor=None):
...
@@ -289,7 +290,7 @@ def _check_roi_extractor(config, roi_extractor, prev_roi_extractor=None):
def
_check_mask_head
(
mask_cfg
,
mask_head
):
def
_check_mask_head
(
mask_cfg
,
mask_head
):
import
torch.
nn
as
nn
from
torch
import
nn
as
nn
if
isinstance
(
mask_cfg
,
list
):
if
isinstance
(
mask_cfg
,
list
):
for
single_mask_cfg
,
single_mask_head
in
zip
(
mask_cfg
,
mask_head
):
for
single_mask_cfg
,
single_mask_head
in
zip
(
mask_cfg
,
mask_head
):
_check_mask_head
(
single_mask_cfg
,
single_mask_head
)
_check_mask_head
(
single_mask_cfg
,
single_mask_head
)
...
@@ -307,7 +308,7 @@ def _check_mask_head(mask_cfg, mask_head):
...
@@ -307,7 +308,7 @@ def _check_mask_head(mask_cfg, mask_head):
def
_check_bbox_head
(
bbox_cfg
,
bbox_head
):
def
_check_bbox_head
(
bbox_cfg
,
bbox_head
):
import
torch.
nn
as
nn
from
torch
import
nn
as
nn
if
isinstance
(
bbox_cfg
,
list
):
if
isinstance
(
bbox_cfg
,
list
):
for
single_bbox_cfg
,
single_bbox_head
in
zip
(
bbox_cfg
,
bbox_head
):
for
single_bbox_cfg
,
single_bbox_head
in
zip
(
bbox_cfg
,
bbox_head
):
_check_bbox_head
(
single_bbox_cfg
,
single_bbox_head
)
_check_bbox_head
(
single_bbox_cfg
,
single_bbox_head
)
...
@@ -357,7 +358,7 @@ def _check_parta2_roi_extractor(config, roi_extractor):
...
@@ -357,7 +358,7 @@ def _check_parta2_roi_extractor(config, roi_extractor):
def
_check_parta2_bbox_head
(
bbox_cfg
,
bbox_head
):
def
_check_parta2_bbox_head
(
bbox_cfg
,
bbox_head
):
import
torch.
nn
as
nn
from
torch
import
nn
as
nn
if
isinstance
(
bbox_cfg
,
list
):
if
isinstance
(
bbox_cfg
,
list
):
for
single_bbox_cfg
,
single_bbox_head
in
zip
(
bbox_cfg
,
bbox_head
):
for
single_bbox_cfg
,
single_bbox_head
in
zip
(
bbox_cfg
,
bbox_head
):
_check_bbox_head
(
single_bbox_cfg
,
single_bbox_head
)
_check_bbox_head
(
single_bbox_cfg
,
single_bbox_head
)
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment