Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
80b39bd0
"git@developer.sourcefind.cn:OpenDAS/fastmoe.git" did not exist on "79f1629753124e9291eb3b02cf414b3e6f7c363e"
Commit
80b39bd0
authored
Jul 04, 2020
by
zhangwenwei
Browse files
Reformat docstrings in code
parent
64d7fbc2
Changes
101
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
59 additions
and
70 deletions
+59
-70
mmdet3d/core/post_processing/box3d_nms.py
mmdet3d/core/post_processing/box3d_nms.py
+1
-1
mmdet3d/core/visualizer/show_result.py
mmdet3d/core/visualizer/show_result.py
+3
-4
mmdet3d/core/voxel/builder.py
mmdet3d/core/voxel/builder.py
+1
-1
mmdet3d/core/voxel/voxel_generator.py
mmdet3d/core/voxel/voxel_generator.py
+2
-2
mmdet3d/datasets/custom_3d.py
mmdet3d/datasets/custom_3d.py
+6
-8
mmdet3d/datasets/kitti_dataset.py
mmdet3d/datasets/kitti_dataset.py
+4
-5
mmdet3d/datasets/lyft_dataset.py
mmdet3d/datasets/lyft_dataset.py
+3
-4
mmdet3d/datasets/nuscenes_dataset.py
mmdet3d/datasets/nuscenes_dataset.py
+3
-4
mmdet3d/datasets/pipelines/data_augment_utils.py
mmdet3d/datasets/pipelines/data_augment_utils.py
+12
-13
mmdet3d/datasets/pipelines/dbsampler.py
mmdet3d/datasets/pipelines/dbsampler.py
+1
-2
mmdet3d/datasets/pipelines/loading.py
mmdet3d/datasets/pipelines/loading.py
+3
-3
mmdet3d/datasets/pipelines/test_time_aug.py
mmdet3d/datasets/pipelines/test_time_aug.py
+2
-3
mmdet3d/datasets/pipelines/transforms_3d.py
mmdet3d/datasets/pipelines/transforms_3d.py
+4
-4
mmdet3d/datasets/scannet_dataset.py
mmdet3d/datasets/scannet_dataset.py
+2
-3
mmdet3d/datasets/sunrgbd_dataset.py
mmdet3d/datasets/sunrgbd_dataset.py
+2
-3
mmdet3d/models/backbones/pointnet2_sa_ssg.py
mmdet3d/models/backbones/pointnet2_sa_ssg.py
+1
-1
mmdet3d/models/backbones/second.py
mmdet3d/models/backbones/second.py
+2
-2
mmdet3d/models/dense_heads/anchor3d_head.py
mmdet3d/models/dense_heads/anchor3d_head.py
+2
-2
mmdet3d/models/dense_heads/free_anchor3d_head.py
mmdet3d/models/dense_heads/free_anchor3d_head.py
+4
-4
mmdet3d/models/dense_heads/parta2_rpn_head.py
mmdet3d/models/dense_heads/parta2_rpn_head.py
+1
-1
No files found.
mmdet3d/core/post_processing/box3d_nms.py
View file @
80b39bd0
...
@@ -10,7 +10,7 @@ def box3d_multiclass_nms(mlvl_bboxes,
...
@@ -10,7 +10,7 @@ def box3d_multiclass_nms(mlvl_bboxes,
max_num
,
max_num
,
cfg
,
cfg
,
mlvl_dir_scores
=
None
):
mlvl_dir_scores
=
None
):
"""Multi-class nms for 3D boxes
"""Multi-class nms for 3D boxes
.
Args:
Args:
mlvl_bboxes (torch.Tensor): Multi-level boxes with shape (N, M).
mlvl_bboxes (torch.Tensor): Multi-level boxes with shape (N, M).
...
...
mmdet3d/core/visualizer/show_result.py
View file @
80b39bd0
import
os.path
as
osp
import
mmcv
import
mmcv
import
numpy
as
np
import
numpy
as
np
import
trimesh
import
trimesh
from
os
import
path
as
osp
def
_write_ply
(
points
,
out_filename
):
def
_write_ply
(
points
,
out_filename
):
"""Write points into ply format for meshlab visualization
"""Write points into ply format for meshlab visualization
.
Args:
Args:
points (np.ndarray): Points in shape (N, dim).
points (np.ndarray): Points in shape (N, dim).
...
@@ -28,7 +27,7 @@ def _write_ply(points, out_filename):
...
@@ -28,7 +27,7 @@ def _write_ply(points, out_filename):
def
_write_oriented_bbox
(
scene_bbox
,
out_filename
):
def
_write_oriented_bbox
(
scene_bbox
,
out_filename
):
"""Export oriented (around Z axis) scene bbox to meshes
"""Export oriented (around Z axis) scene bbox to meshes
.
Args:
Args:
scene_bbox(list[ndarray] or ndarray): xyz pos of center and
scene_bbox(list[ndarray] or ndarray): xyz pos of center and
...
...
mmdet3d/core/voxel/builder.py
View file @
80b39bd0
...
@@ -4,7 +4,7 @@ from . import voxel_generator
...
@@ -4,7 +4,7 @@ from . import voxel_generator
def
build_voxel_generator
(
cfg
,
**
kwargs
):
def
build_voxel_generator
(
cfg
,
**
kwargs
):
"""Builder of voxel generator"""
"""Builder of voxel generator
.
"""
if
isinstance
(
cfg
,
voxel_generator
.
VoxelGenerator
):
if
isinstance
(
cfg
,
voxel_generator
.
VoxelGenerator
):
return
cfg
return
cfg
elif
isinstance
(
cfg
,
dict
):
elif
isinstance
(
cfg
,
dict
):
...
...
mmdet3d/core/voxel/voxel_generator.py
View file @
80b39bd0
...
@@ -3,7 +3,7 @@ import numpy as np
...
@@ -3,7 +3,7 @@ import numpy as np
class
VoxelGenerator
(
object
):
class
VoxelGenerator
(
object
):
"""Voxel generator in numpy implementation
"""Voxel generator in numpy implementation
.
Args:
Args:
voxel_size (list[float]): Size of a single voxel
voxel_size (list[float]): Size of a single voxel
...
@@ -33,7 +33,7 @@ class VoxelGenerator(object):
...
@@ -33,7 +33,7 @@ class VoxelGenerator(object):
self
.
_grid_size
=
grid_size
self
.
_grid_size
=
grid_size
def
generate
(
self
,
points
):
def
generate
(
self
,
points
):
"""Generate voxels given points"""
"""Generate voxels given points
.
"""
return
points_to_voxel
(
points
,
self
.
_voxel_size
,
return
points_to_voxel
(
points
,
self
.
_voxel_size
,
self
.
_point_cloud_range
,
self
.
_max_num_points
,
self
.
_point_cloud_range
,
self
.
_max_num_points
,
True
,
self
.
_max_voxels
)
True
,
self
.
_max_voxels
)
...
...
mmdet3d/datasets/custom_3d.py
View file @
80b39bd0
import
os.path
as
osp
import
tempfile
import
mmcv
import
mmcv
import
numpy
as
np
import
numpy
as
np
import
tempfile
from
os
import
path
as
osp
from
torch.utils.data
import
Dataset
from
torch.utils.data
import
Dataset
from
mmdet.datasets
import
DATASETS
from
mmdet.datasets
import
DATASETS
...
@@ -12,7 +11,7 @@ from .pipelines import Compose
...
@@ -12,7 +11,7 @@ from .pipelines import Compose
@
DATASETS
.
register_module
()
@
DATASETS
.
register_module
()
class
Custom3DDataset
(
Dataset
):
class
Custom3DDataset
(
Dataset
):
"""Customized 3D dataset
"""Customized 3D dataset
.
This is the base dataset of SUNRGB-D, ScanNet, nuScenes, and KITTI
This is the base dataset of SUNRGB-D, ScanNet, nuScenes, and KITTI
dataset.
dataset.
...
@@ -179,7 +178,7 @@ class Custom3DDataset(Dataset):
...
@@ -179,7 +178,7 @@ class Custom3DDataset(Dataset):
from
mmdet3d.core.evaluation
import
indoor_eval
from
mmdet3d.core.evaluation
import
indoor_eval
assert
isinstance
(
assert
isinstance
(
results
,
list
),
f
'Expect results to be list, got
{
type
(
results
)
}
.'
results
,
list
),
f
'Expect results to be list, got
{
type
(
results
)
}
.'
assert
len
(
results
)
>
0
,
f
'Expect length of results > 0.'
assert
len
(
results
)
>
0
,
'Expect length of results > 0.'
assert
len
(
results
)
==
len
(
self
.
data_infos
)
assert
len
(
results
)
==
len
(
self
.
data_infos
)
assert
isinstance
(
assert
isinstance
(
results
[
0
],
dict
results
[
0
],
dict
...
@@ -220,8 +219,7 @@ class Custom3DDataset(Dataset):
...
@@ -220,8 +219,7 @@ class Custom3DDataset(Dataset):
"""Set flag according to image aspect ratio.
"""Set flag according to image aspect ratio.
Images with aspect ratio greater than 1 will be set as group 1,
Images with aspect ratio greater than 1 will be set as group 1,
otherwise group 0.
otherwise group 0. In 3D datasets, they are all the same, thus are all
In 3D datasets, they are all the same, thus are all zeros
zeros
"""
"""
self
.
flag
=
np
.
zeros
(
len
(
self
),
dtype
=
np
.
uint8
)
self
.
flag
=
np
.
zeros
(
len
(
self
),
dtype
=
np
.
uint8
)
mmdet3d/datasets/kitti_dataset.py
View file @
80b39bd0
import
copy
import
copy
import
os
import
os.path
as
osp
import
tempfile
import
mmcv
import
mmcv
import
numpy
as
np
import
numpy
as
np
import
os
import
tempfile
import
torch
import
torch
from
mmcv.utils
import
print_log
from
mmcv.utils
import
print_log
from
os
import
path
as
osp
from
mmdet.datasets
import
DATASETS
from
mmdet.datasets
import
DATASETS
from
..core
import
show_result
from
..core
import
show_result
...
@@ -16,7 +15,7 @@ from .custom_3d import Custom3DDataset
...
@@ -16,7 +15,7 @@ from .custom_3d import Custom3DDataset
@
DATASETS
.
register_module
()
@
DATASETS
.
register_module
()
class
KittiDataset
(
Custom3DDataset
):
class
KittiDataset
(
Custom3DDataset
):
"""KITTI Dataset
"""KITTI Dataset
.
This class serves as the API for experiments on the KITTI Dataset.
This class serves as the API for experiments on the KITTI Dataset.
...
...
mmdet3d/datasets/lyft_dataset.py
View file @
80b39bd0
import
os.path
as
osp
import
tempfile
import
mmcv
import
mmcv
import
numpy
as
np
import
numpy
as
np
import
pandas
as
pd
import
pandas
as
pd
import
tempfile
from
lyft_dataset_sdk.lyftdataset
import
LyftDataset
as
Lyft
from
lyft_dataset_sdk.lyftdataset
import
LyftDataset
as
Lyft
from
lyft_dataset_sdk.utils.data_classes
import
Box
as
LyftBox
from
lyft_dataset_sdk.utils.data_classes
import
Box
as
LyftBox
from
os
import
path
as
osp
from
pyquaternion
import
Quaternion
from
pyquaternion
import
Quaternion
from
mmdet3d.core.evaluation.lyft_eval
import
lyft_eval
from
mmdet3d.core.evaluation.lyft_eval
import
lyft_eval
...
@@ -16,7 +15,7 @@ from .custom_3d import Custom3DDataset
...
@@ -16,7 +15,7 @@ from .custom_3d import Custom3DDataset
@
DATASETS
.
register_module
()
@
DATASETS
.
register_module
()
class
LyftDataset
(
Custom3DDataset
):
class
LyftDataset
(
Custom3DDataset
):
"""Lyft Dataset
"""Lyft Dataset
.
This class serves as the API for experiments on the Lyft Dataset.
This class serves as the API for experiments on the Lyft Dataset.
...
...
mmdet3d/datasets/nuscenes_dataset.py
View file @
80b39bd0
import
os.path
as
osp
import
tempfile
import
mmcv
import
mmcv
import
numpy
as
np
import
numpy
as
np
import
pyquaternion
import
pyquaternion
import
tempfile
from
nuscenes.utils.data_classes
import
Box
as
NuScenesBox
from
nuscenes.utils.data_classes
import
Box
as
NuScenesBox
from
os
import
path
as
osp
from
mmdet.datasets
import
DATASETS
from
mmdet.datasets
import
DATASETS
from
..core
import
show_result
from
..core
import
show_result
...
@@ -14,7 +13,7 @@ from .custom_3d import Custom3DDataset
...
@@ -14,7 +13,7 @@ from .custom_3d import Custom3DDataset
@
DATASETS
.
register_module
()
@
DATASETS
.
register_module
()
class
NuScenesDataset
(
Custom3DDataset
):
class
NuScenesDataset
(
Custom3DDataset
):
"""NuScenes Dataset
"""NuScenes Dataset
.
This class serves as the API for experiments on the NuScenes Dataset.
This class serves as the API for experiments on the NuScenes Dataset.
...
...
mmdet3d/datasets/pipelines/data_augment_utils.py
View file @
80b39bd0
import
warnings
import
numba
import
numba
import
numpy
as
np
import
numpy
as
np
import
warnings
from
numba.errors
import
NumbaPerformanceWarning
from
numba.errors
import
NumbaPerformanceWarning
from
mmdet3d.core.bbox
import
box_np_ops
from
mmdet3d.core.bbox
import
box_np_ops
...
@@ -44,11 +43,11 @@ def box_collision_test(boxes, qboxes, clockwise=True):
...
@@ -44,11 +43,11 @@ def box_collision_test(boxes, qboxes, clockwise=True):
max
(
boxes_standup
[
i
,
1
],
qboxes_standup
[
j
,
1
]))
max
(
boxes_standup
[
i
,
1
],
qboxes_standup
[
j
,
1
]))
if
ih
>
0
:
if
ih
>
0
:
for
k
in
range
(
4
):
for
k
in
range
(
4
):
for
l
in
range
(
4
):
for
box_
l
in
range
(
4
):
A
=
lines_boxes
[
i
,
k
,
0
]
A
=
lines_boxes
[
i
,
k
,
0
]
B
=
lines_boxes
[
i
,
k
,
1
]
B
=
lines_boxes
[
i
,
k
,
1
]
C
=
lines_qboxes
[
j
,
l
,
0
]
C
=
lines_qboxes
[
j
,
box_
l
,
0
]
D
=
lines_qboxes
[
j
,
l
,
1
]
D
=
lines_qboxes
[
j
,
box_
l
,
1
]
acd
=
(
D
[
1
]
-
A
[
1
])
*
(
C
[
0
]
-
acd
=
(
D
[
1
]
-
A
[
1
])
*
(
C
[
0
]
-
A
[
0
])
>
(
C
[
1
]
-
A
[
1
])
*
(
A
[
0
])
>
(
C
[
1
]
-
A
[
1
])
*
(
D
[
0
]
-
A
[
0
])
D
[
0
]
-
A
[
0
])
...
@@ -71,15 +70,15 @@ def box_collision_test(boxes, qboxes, clockwise=True):
...
@@ -71,15 +70,15 @@ def box_collision_test(boxes, qboxes, clockwise=True):
# now check complete overlap.
# now check complete overlap.
# box overlap qbox:
# box overlap qbox:
box_overlap_qbox
=
True
box_overlap_qbox
=
True
for
l
in
range
(
4
):
# point l in qboxes
for
box_
l
in
range
(
4
):
# point l in qboxes
for
k
in
range
(
4
):
# corner k in boxes
for
k
in
range
(
4
):
# corner k in boxes
vec
=
boxes
[
i
,
k
]
-
boxes
[
i
,
(
k
+
1
)
%
4
]
vec
=
boxes
[
i
,
k
]
-
boxes
[
i
,
(
k
+
1
)
%
4
]
if
clockwise
:
if
clockwise
:
vec
=
-
vec
vec
=
-
vec
cross
=
vec
[
1
]
*
(
cross
=
vec
[
1
]
*
(
boxes
[
i
,
k
,
0
]
-
qboxes
[
j
,
l
,
0
])
boxes
[
i
,
k
,
0
]
-
qboxes
[
j
,
box_
l
,
0
])
cross
-=
vec
[
0
]
*
(
cross
-=
vec
[
0
]
*
(
boxes
[
i
,
k
,
1
]
-
qboxes
[
j
,
l
,
1
])
boxes
[
i
,
k
,
1
]
-
qboxes
[
j
,
box_
l
,
1
])
if
cross
>=
0
:
if
cross
>=
0
:
box_overlap_qbox
=
False
box_overlap_qbox
=
False
break
break
...
@@ -88,15 +87,15 @@ def box_collision_test(boxes, qboxes, clockwise=True):
...
@@ -88,15 +87,15 @@ def box_collision_test(boxes, qboxes, clockwise=True):
if
box_overlap_qbox
is
False
:
if
box_overlap_qbox
is
False
:
qbox_overlap_box
=
True
qbox_overlap_box
=
True
for
l
in
range
(
4
):
# point l in boxes
for
box_
l
in
range
(
4
):
# point
box_
l in boxes
for
k
in
range
(
4
):
# corner k in qboxes
for
k
in
range
(
4
):
# corner k in qboxes
vec
=
qboxes
[
j
,
k
]
-
qboxes
[
j
,
(
k
+
1
)
%
4
]
vec
=
qboxes
[
j
,
k
]
-
qboxes
[
j
,
(
k
+
1
)
%
4
]
if
clockwise
:
if
clockwise
:
vec
=
-
vec
vec
=
-
vec
cross
=
vec
[
1
]
*
(
cross
=
vec
[
1
]
*
(
qboxes
[
j
,
k
,
0
]
-
boxes
[
i
,
l
,
0
])
qboxes
[
j
,
k
,
0
]
-
boxes
[
i
,
box_
l
,
0
])
cross
-=
vec
[
0
]
*
(
cross
-=
vec
[
0
]
*
(
qboxes
[
j
,
k
,
1
]
-
boxes
[
i
,
l
,
1
])
qboxes
[
j
,
k
,
1
]
-
boxes
[
i
,
box_
l
,
1
])
if
cross
>=
0
:
#
if
cross
>=
0
:
#
qbox_overlap_box
=
False
qbox_overlap_box
=
False
break
break
...
@@ -264,8 +263,8 @@ def noise_per_object_v3_(gt_boxes,
...
@@ -264,8 +263,8 @@ def noise_per_object_v3_(gt_boxes,
center_noise_std
=
1.0
,
center_noise_std
=
1.0
,
global_random_rot_range
=
np
.
pi
/
4
,
global_random_rot_range
=
np
.
pi
/
4
,
num_try
=
100
):
num_try
=
100
):
"""random rotate or remove each groundtrutn independently.
"""random rotate or remove each groundtrutn independently.
use kitti viewer
use kitti viewer
to test this function points_transform_
to test this function points_transform_
Args:
Args:
gt_boxes: [N, 7], gt box in lidar.points_transform_
gt_boxes: [N, 7], gt box in lidar.points_transform_
...
...
mmdet3d/datasets/pipelines/dbsampler.py
View file @
80b39bd0
import
copy
import
copy
import
numpy
as
np
import
os
import
os
import
pickle
import
pickle
import
numpy
as
np
from
mmdet3d.core.bbox
import
box_np_ops
from
mmdet3d.core.bbox
import
box_np_ops
from
mmdet3d.datasets.pipelines
import
data_augment_utils
from
mmdet3d.datasets.pipelines
import
data_augment_utils
from
..registry
import
OBJECTSAMPLERS
from
..registry
import
OBJECTSAMPLERS
...
...
mmdet3d/datasets/pipelines/loading.py
View file @
80b39bd0
...
@@ -7,7 +7,7 @@ from mmdet.datasets.pipelines import LoadAnnotations
...
@@ -7,7 +7,7 @@ from mmdet.datasets.pipelines import LoadAnnotations
@
PIPELINES
.
register_module
()
@
PIPELINES
.
register_module
()
class
LoadMultiViewImageFromFiles
(
object
):
class
LoadMultiViewImageFromFiles
(
object
):
"""
Load multi channel images from a list of separate channel files.
"""Load multi channel images from a list of separate channel files.
Expects results['img_filename'] to be a list of filenames
Expects results['img_filename'] to be a list of filenames
"""
"""
...
@@ -43,7 +43,7 @@ class LoadMultiViewImageFromFiles(object):
...
@@ -43,7 +43,7 @@ class LoadMultiViewImageFromFiles(object):
@
PIPELINES
.
register_module
()
@
PIPELINES
.
register_module
()
class
LoadPointsFromMultiSweeps
(
object
):
class
LoadPointsFromMultiSweeps
(
object
):
"""Load points from multiple sweeps
"""Load points from multiple sweeps
.
This is usually used for nuScenes dataset to utilize previous sweeps.
This is usually used for nuScenes dataset to utilize previous sweeps.
...
@@ -143,7 +143,7 @@ class PointSegClassMapping(object):
...
@@ -143,7 +143,7 @@ class PointSegClassMapping(object):
@
PIPELINES
.
register_module
()
@
PIPELINES
.
register_module
()
class
NormalizePointsColor
(
object
):
class
NormalizePointsColor
(
object
):
"""Normalize color of points
"""Normalize color of points
.
Normalize color of the points.
Normalize color of the points.
...
...
mmdet3d/datasets/pipelines/test_time_aug.py
View file @
80b39bd0
import
mmcv
import
warnings
import
warnings
from
copy
import
deepcopy
from
copy
import
deepcopy
import
mmcv
from
mmdet.datasets.builder
import
PIPELINES
from
mmdet.datasets.builder
import
PIPELINES
from
mmdet.datasets.pipelines
import
Compose
from
mmdet.datasets.pipelines
import
Compose
@
PIPELINES
.
register_module
()
@
PIPELINES
.
register_module
()
class
MultiScaleFlipAug3D
(
object
):
class
MultiScaleFlipAug3D
(
object
):
"""Test-time augmentation with multiple scales and flipping
"""Test-time augmentation with multiple scales and flipping
.
Args:
Args:
transforms (list[dict]): Transforms to apply in each augmentation.
transforms (list[dict]): Transforms to apply in each augmentation.
...
...
mmdet3d/datasets/pipelines/transforms_3d.py
View file @
80b39bd0
...
@@ -91,7 +91,7 @@ class RandomFlip3D(RandomFlip):
...
@@ -91,7 +91,7 @@ class RandomFlip3D(RandomFlip):
@
PIPELINES
.
register_module
()
@
PIPELINES
.
register_module
()
class
ObjectSample
(
object
):
class
ObjectSample
(
object
):
"""Sample GT objects to the data
"""Sample GT objects to the data
.
Args:
Args:
db_sampler (dict): Config dict of the database sampler.
db_sampler (dict): Config dict of the database sampler.
...
@@ -168,7 +168,7 @@ class ObjectSample(object):
...
@@ -168,7 +168,7 @@ class ObjectSample(object):
@
PIPELINES
.
register_module
()
@
PIPELINES
.
register_module
()
class
ObjectNoise
(
object
):
class
ObjectNoise
(
object
):
"""Apply noise to each GT objects in the scene
"""Apply noise to each GT objects in the scene
.
Args:
Args:
translation_std (list, optional): Standard deviation of the
translation_std (list, optional): Standard deviation of the
...
@@ -221,7 +221,7 @@ class ObjectNoise(object):
...
@@ -221,7 +221,7 @@ class ObjectNoise(object):
@
PIPELINES
.
register_module
()
@
PIPELINES
.
register_module
()
class
GlobalRotScaleTrans
(
object
):
class
GlobalRotScaleTrans
(
object
):
"""Apply global rotation, scaling and translation to a 3D scene
"""Apply global rotation, scaling and translation to a 3D scene
.
Args:
Args:
rot_range (list[float]): Range of rotation angle.
rot_range (list[float]): Range of rotation angle.
...
@@ -374,7 +374,7 @@ class PointsRangeFilter(object):
...
@@ -374,7 +374,7 @@ class PointsRangeFilter(object):
@
PIPELINES
.
register_module
()
@
PIPELINES
.
register_module
()
class
ObjectNameFilter
(
object
):
class
ObjectNameFilter
(
object
):
"""Filter GT objects by their names
"""Filter GT objects by their names
.
Args:
Args:
classes (list[str]): list of class names to be kept for training
classes (list[str]): list of class names to be kept for training
...
...
mmdet3d/datasets/scannet_dataset.py
View file @
80b39bd0
import
os.path
as
osp
import
numpy
as
np
import
numpy
as
np
from
os
import
path
as
osp
from
mmdet3d.core
import
show_result
from
mmdet3d.core
import
show_result
from
mmdet3d.core.bbox
import
DepthInstance3DBoxes
from
mmdet3d.core.bbox
import
DepthInstance3DBoxes
...
@@ -10,7 +9,7 @@ from .custom_3d import Custom3DDataset
...
@@ -10,7 +9,7 @@ from .custom_3d import Custom3DDataset
@
DATASETS
.
register_module
()
@
DATASETS
.
register_module
()
class
ScanNetDataset
(
Custom3DDataset
):
class
ScanNetDataset
(
Custom3DDataset
):
"""ScanNet Dataset
"""ScanNet Dataset
.
This class serves as the API for experiments on the ScanNet Dataset.
This class serves as the API for experiments on the ScanNet Dataset.
...
...
mmdet3d/datasets/sunrgbd_dataset.py
View file @
80b39bd0
import
os.path
as
osp
import
numpy
as
np
import
numpy
as
np
from
os
import
path
as
osp
from
mmdet3d.core
import
show_result
from
mmdet3d.core
import
show_result
from
mmdet3d.core.bbox
import
DepthInstance3DBoxes
from
mmdet3d.core.bbox
import
DepthInstance3DBoxes
...
@@ -10,7 +9,7 @@ from .custom_3d import Custom3DDataset
...
@@ -10,7 +9,7 @@ from .custom_3d import Custom3DDataset
@
DATASETS
.
register_module
()
@
DATASETS
.
register_module
()
class
SUNRGBDDataset
(
Custom3DDataset
):
class
SUNRGBDDataset
(
Custom3DDataset
):
"""SUNRGBD Dataset
"""SUNRGBD Dataset
.
This class serves as the API for experiments on the SUNRGBD Dataset.
This class serves as the API for experiments on the SUNRGBD Dataset.
...
...
mmdet3d/models/backbones/pointnet2_sa_ssg.py
View file @
80b39bd0
import
torch
import
torch
import
torch.nn
as
nn
from
mmcv.runner
import
load_checkpoint
from
mmcv.runner
import
load_checkpoint
from
torch
import
nn
as
nn
from
mmdet3d.ops
import
PointFPModule
,
PointSAModule
from
mmdet3d.ops
import
PointFPModule
,
PointSAModule
from
mmdet.models
import
BACKBONES
from
mmdet.models
import
BACKBONES
...
...
mmdet3d/models/backbones/second.py
View file @
80b39bd0
import
torch.nn
as
nn
from
mmcv.cnn
import
build_conv_layer
,
build_norm_layer
from
mmcv.cnn
import
build_conv_layer
,
build_norm_layer
from
mmcv.runner
import
load_checkpoint
from
mmcv.runner
import
load_checkpoint
from
torch
import
nn
as
nn
from
mmdet.models
import
BACKBONES
from
mmdet.models
import
BACKBONES
@
BACKBONES
.
register_module
()
@
BACKBONES
.
register_module
()
class
SECOND
(
nn
.
Module
):
class
SECOND
(
nn
.
Module
):
"""Backbone network for SECOND/PointPillars/PartA2/MVXNet
"""Backbone network for SECOND/PointPillars/PartA2/MVXNet
.
Args:
Args:
in_channels (int): Input channels
in_channels (int): Input channels
...
...
mmdet3d/models/dense_heads/anchor3d_head.py
View file @
80b39bd0
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
torch.nn
as
nn
from
mmcv.cnn
import
bias_init_with_prob
,
normal_init
from
mmcv.cnn
import
bias_init_with_prob
,
normal_init
from
torch
import
nn
as
nn
from
mmdet3d.core
import
(
PseudoSampler
,
box3d_multiclass_nms
,
limit_period
,
from
mmdet3d.core
import
(
PseudoSampler
,
box3d_multiclass_nms
,
limit_period
,
xywhr2xyxyr
)
xywhr2xyxyr
)
...
@@ -244,7 +244,7 @@ class Anchor3DHead(nn.Module, AnchorTrainMixin):
...
@@ -244,7 +244,7 @@ class Anchor3DHead(nn.Module, AnchorTrainMixin):
@
staticmethod
@
staticmethod
def
add_sin_difference
(
boxes1
,
boxes2
):
def
add_sin_difference
(
boxes1
,
boxes2
):
"""Convert the rotation difference to difference in sine function
"""Convert the rotation difference to difference in sine function
.
Args:
Args:
boxes1 (torch.Tensor): shape (NxC), where C>=7 and
boxes1 (torch.Tensor): shape (NxC), where C>=7 and
...
...
mmdet3d/models/dense_heads/free_anchor3d_head.py
View file @
80b39bd0
import
torch
import
torch
import
torch.nn
.
functional
as
F
from
torch.nn
import
functional
as
F
from
mmdet3d.core.bbox
import
bbox_overlaps_nearest_3d
from
mmdet3d.core.bbox
import
bbox_overlaps_nearest_3d
from
mmdet.models
import
HEADS
from
mmdet.models
import
HEADS
...
@@ -9,7 +9,7 @@ from .train_mixins import get_direction_target
...
@@ -9,7 +9,7 @@ from .train_mixins import get_direction_target
@
HEADS
.
register_module
()
@
HEADS
.
register_module
()
class
FreeAnchor3DHead
(
Anchor3DHead
):
class
FreeAnchor3DHead
(
Anchor3DHead
):
"""`FreeAnchor <https://arxiv.org/abs/1909.02466>`_ head for 3D detection
"""`FreeAnchor <https://arxiv.org/abs/1909.02466>`_ head for 3D detection
.
Note:
Note:
This implementation is directly modified from the `mmdet implementation
This implementation is directly modified from the `mmdet implementation
...
@@ -237,7 +237,7 @@ class FreeAnchor3DHead(Anchor3DHead):
...
@@ -237,7 +237,7 @@ class FreeAnchor3DHead(Anchor3DHead):
return
losses
return
losses
def
positive_bag_loss
(
self
,
matched_cls_prob
,
matched_box_prob
):
def
positive_bag_loss
(
self
,
matched_cls_prob
,
matched_box_prob
):
"""Generate positive bag loss
"""Generate positive bag loss
.
Args:
Args:
matched_cls_prob (torch.Tensor): Classification probability
matched_cls_prob (torch.Tensor): Classification probability
...
@@ -259,7 +259,7 @@ class FreeAnchor3DHead(Anchor3DHead):
...
@@ -259,7 +259,7 @@ class FreeAnchor3DHead(Anchor3DHead):
bag_prob
,
torch
.
ones_like
(
bag_prob
),
reduction
=
'none'
)
bag_prob
,
torch
.
ones_like
(
bag_prob
),
reduction
=
'none'
)
def
negative_bag_loss
(
self
,
cls_prob
,
box_prob
):
def
negative_bag_loss
(
self
,
cls_prob
,
box_prob
):
"""Generate negative bag loss
"""Generate negative bag loss
.
Args:
Args:
cls_prob (torch.Tensor): Classification probability
cls_prob (torch.Tensor): Classification probability
...
...
mmdet3d/models/dense_heads/parta2_rpn_head.py
View file @
80b39bd0
...
@@ -11,7 +11,7 @@ from .anchor3d_head import Anchor3DHead
...
@@ -11,7 +11,7 @@ from .anchor3d_head import Anchor3DHead
@
HEADS
.
register_module
()
@
HEADS
.
register_module
()
class
PartA2RPNHead
(
Anchor3DHead
):
class
PartA2RPNHead
(
Anchor3DHead
):
"""RPN head for PartA2
"""RPN head for PartA2
.
Note:
Note:
The main difference between the PartA2 RPN head and the Anchor3DHead
The main difference between the PartA2 RPN head and the Anchor3DHead
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment