OpenDAS / mmdetection3d / Commits / 333536f6

Unverified commit 333536f6, authored Apr 06, 2022 by Wenwei Zhang, committed by GitHub on Apr 06, 2022.

Release v1.0.0rc1

Parents: 9c7270d0, f747daab
Changes: 219 files in total; this page shows 20 changed files with 243 additions and 65 deletions (+243 -65).
Changed files on this page:

mmdet3d/datasets/lyft_dataset.py                   +6   -3
mmdet3d/datasets/nuscenes_dataset.py               +6   -3
mmdet3d/datasets/pipelines/dbsampler.py            +16  -2
mmdet3d/datasets/pipelines/loading.py              +4   -4
mmdet3d/datasets/pipelines/transforms_3d.py        +4   -10
mmdet3d/datasets/s3dis_dataset.py                  +14  -8
mmdet3d/datasets/scannet_dataset.py                +159 -7
mmdet3d/datasets/sunrgbd_dataset.py                +6   -4
mmdet3d/datasets/waymo_dataset.py                  +4   -2
mmdet3d/models/dense_heads/centerpoint_head.py     +1   -1
mmdet3d/models/dense_heads/groupfree3d_head.py     +2   -1
mmdet3d/models/dense_heads/parta2_rpn_head.py      +2   -1
mmdet3d/models/dense_heads/point_rpn_head.py       +2   -1
mmdet3d/models/dense_heads/vote_head.py            +2   -1
mmdet3d/models/detectors/mvx_two_stage.py          +1   -1
mmdet3d/models/detectors/parta2.py                 +1   -1
mmdet3d/models/detectors/voxelnet.py               +1   -1
mmdet3d/models/middle_encoders/pillar_scatter.py   +2   -2
mmdet3d/models/middle_encoders/sparse_encoder.py   +5   -6
mmdet3d/models/middle_encoders/sparse_unet.py      +5   -6
mmdet3d/datasets/lyft_dataset.py

@@ -86,7 +86,8 @@ class LyftDataset(Custom3DDataset):
                  modality=None,
                  box_type_3d='LiDAR',
                  filter_empty_gt=True,
-                 test_mode=False):
+                 test_mode=False,
+                 **kwargs):
         self.load_interval = load_interval
         super().__init__(
             data_root=data_root,
@@ -96,7 +97,8 @@ class LyftDataset(Custom3DDataset):
             modality=modality,
             box_type_3d=box_type_3d,
             filter_empty_gt=filter_empty_gt,
-            test_mode=test_mode)
+            test_mode=test_mode,
+            **kwargs)

         if self.modality is None:
             self.modality = dict(
@@ -116,7 +118,8 @@ class LyftDataset(Custom3DDataset):
         Returns:
             list[dict]: List of annotations sorted by timestamps.
         """
-        data = mmcv.load(ann_file)
+        # loading data from a file-like object needs file format
+        data = mmcv.load(ann_file, file_format='pkl')
         data_infos = list(sorted(data['infos'], key=lambda e: e['timestamp']))
         data_infos = data_infos[::self.load_interval]
         self.metadata = data['metadata']
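Both edits above recur in the other dataset classes touched by this commit: accepting `**kwargs` lets config keys that the class does not name itself (for example the new file-client options) pass through to `Custom3DDataset`, and `file_format='pkl'` keeps `mmcv.load` working when the annotation file arrives as a file-like object rather than a path. A minimal sketch of the `mmcv.load` behaviour (the path is illustrative):

    import mmcv

    # With a path, mmcv.load can infer the format from the '.pkl' extension.
    infos = mmcv.load('data/lyft/lyft_infos_train.pkl')

    # A file-like object has no extension, so the format must be stated
    # explicitly, which is exactly what the new call does.
    with open('data/lyft/lyft_infos_train.pkl', 'rb') as f:
        infos = mmcv.load(f, file_format='pkl')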
mmdet3d/datasets/nuscenes_dataset.py

@@ -125,7 +125,8 @@ class NuScenesDataset(Custom3DDataset):
                  filter_empty_gt=True,
                  test_mode=False,
                  eval_version='detection_cvpr_2019',
-                 use_valid_flag=False):
+                 use_valid_flag=False,
+                 **kwargs):
         self.load_interval = load_interval
         self.use_valid_flag = use_valid_flag
         super().__init__(
@@ -136,7 +137,8 @@ class NuScenesDataset(Custom3DDataset):
             modality=modality,
             box_type_3d=box_type_3d,
             filter_empty_gt=filter_empty_gt,
-            test_mode=test_mode)
+            test_mode=test_mode,
+            **kwargs)

         self.with_velocity = with_velocity
         self.eval_version = eval_version
@@ -184,7 +186,8 @@ class NuScenesDataset(Custom3DDataset):
         Returns:
             list[dict]: List of annotations sorted by timestamps.
         """
-        data = mmcv.load(ann_file)
+        # loading data from a file-like object needs file format
+        data = mmcv.load(ann_file, file_format='pkl')
         data_infos = list(sorted(data['infos'], key=lambda e: e['timestamp']))
         data_infos = data_infos[::self.load_interval]
         self.metadata = data['metadata']
mmdet3d/datasets/pipelines/dbsampler.py

 # Copyright (c) OpenMMLab. All rights reserved.
 import copy
 import os
+import warnings

 import mmcv
 import numpy as np
@@ -104,7 +105,8 @@ class DataBaseSampler(object):
                      type='LoadPointsFromFile',
                      coord_type='LIDAR',
                      load_dim=4,
-                     use_dim=[0, 1, 2, 3])):
+                     use_dim=[0, 1, 2, 3]),
+                 file_client_args=dict(backend='disk')):
         super().__init__()
         self.data_root = data_root
         self.info_path = info_path
@@ -114,8 +116,20 @@ class DataBaseSampler(object):
         self.cat2label = {name: i for i, name in enumerate(classes)}
         self.label2cat = {i: name for i, name in enumerate(classes)}
         self.points_loader = mmcv.build_from_cfg(points_loader, PIPELINES)
+        self.file_client = mmcv.FileClient(**file_client_args)

-        db_infos = mmcv.load(info_path)
+        # load data base infos
+        if hasattr(self.file_client, 'get_local_path'):
+            with self.file_client.get_local_path(info_path) as local_path:
+                # loading data from a file-like object needs file format
+                db_infos = mmcv.load(open(local_path, 'rb'), file_format='pkl')
+        else:
+            warnings.warn(
+                'The used MMCV version does not have get_local_path. '
+                f'We treat the {info_path} as local paths and it '
+                'might cause errors if the path is not a local path. '
+                'Please use MMCV>= 1.3.16 if you meet errors.')
+            db_infos = mmcv.load(info_path)

         # filter database infos
         from mmdet3d.utils import get_root_logger
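The new `file_client_args` argument and the `get_local_path` branch let the GT-sampling database (`*_dbinfos_*.pkl`) live on a remote storage backend, while older MMCV versions fall back to plain local loading with a warning. A hedged sketch of how the option might appear in a config dict; the surrounding keys follow the usual KITTI db_sampler layout and are illustrative, not part of this commit:

    db_sampler = dict(
        data_root='data/kitti/',
        info_path='data/kitti/kitti_dbinfos_train.pkl',
        rate=1.0,
        prepare=dict(filter_by_difficulty=[-1], filter_by_min_points=dict(Car=5)),
        classes=['Car'],
        sample_groups=dict(Car=15),
        # New in this commit: defaults to local disk; a remote backend could be
        # configured here instead (see mmcv.FileClient for the valid backends).
        file_client_args=dict(backend='disk'))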
mmdet3d/datasets/pipelines/loading.py

@@ -518,7 +518,7 @@ class LoadAnnotations3D(LoadAnnotations):
                  with_seg=False,
                  with_bbox_depth=False,
                  poly2mask=True,
-                 seg_3d_dtype='int',
+                 seg_3d_dtype=np.int64,
                  file_client_args=dict(backend='disk')):
         super().__init__(
             with_bbox,
@@ -600,11 +600,11 @@ class LoadAnnotations3D(LoadAnnotations):
             self.file_client = mmcv.FileClient(**self.file_client_args)
         try:
             mask_bytes = self.file_client.get(pts_instance_mask_path)
-            pts_instance_mask = np.frombuffer(mask_bytes, dtype=np.int)
+            pts_instance_mask = np.frombuffer(mask_bytes, dtype=np.int64)
         except ConnectionError:
             mmcv.check_file_exist(pts_instance_mask_path)
             pts_instance_mask = np.fromfile(
-                pts_instance_mask_path, dtype=np.long)
+                pts_instance_mask_path, dtype=np.int64)

         results['pts_instance_mask'] = pts_instance_mask
         results['pts_mask_fields'].append('pts_instance_mask')
@@ -631,7 +631,7 @@ class LoadAnnotations3D(LoadAnnotations):
         except ConnectionError:
             mmcv.check_file_exist(pts_semantic_mask_path)
             pts_semantic_mask = np.fromfile(
-                pts_semantic_mask_path, dtype=np.long)
+                pts_semantic_mask_path, dtype=np.int64)

         results['pts_semantic_mask'] = pts_semantic_mask
         results['pts_seg_fields'].append('pts_semantic_mask')
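The dtype edits in this file, and the matching `np.long`/`np.int` removals elsewhere in the commit, replace NumPy's deprecated Python-type aliases with an explicit `np.int64`. Those aliases were deprecated in NumPy 1.20 and removed in later releases, so the old spelling breaks on recent NumPy, and the explicit dtype also makes the on-disk mask format unambiguous. A quick illustration with a hypothetical mask file name:

    import numpy as np

    # Instance/semantic point masks are stored as 64-bit integers, so the
    # dtype is now spelled out instead of relying on the removed aliases.
    mask = np.fromfile('scene0000_00_ins_label.bin', dtype=np.int64)
    labels = mask.astype(np.int64)  # was .astype(np.long) before this commit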
mmdet3d/datasets/pipelines/transforms_3d.py

@@ -356,7 +356,7 @@ class ObjectSample(object):
             input_dict['img'] = sampled_dict['img']

         input_dict['gt_bboxes_3d'] = gt_bboxes_3d
-        input_dict['gt_labels_3d'] = gt_labels_3d.astype(np.long)
+        input_dict['gt_labels_3d'] = gt_labels_3d.astype(np.int64)
         input_dict['points'] = points

         return input_dict
@@ -907,9 +907,9 @@ class PointSample(object):
         point_range = range(len(points))
         if sample_range is not None and not replace:
             # Only sampling the near points when len(points) >= num_samples
-            depth = np.linalg.norm(points.tensor, axis=1)
-            far_inds = np.where(depth >= sample_range)[0]
-            near_inds = np.where(depth < sample_range)[0]
+            dist = np.linalg.norm(points.tensor, axis=1)
+            far_inds = np.where(dist >= sample_range)[0]
+            near_inds = np.where(dist < sample_range)[0]
             # in case there are too many far points
             if len(far_inds) > num_samples:
                 far_inds = np.random.choice(
@@ -936,12 +936,6 @@ class PointSample(object):
                 and 'pts_semantic_mask' keys are updated in the result dict.
         """
         points = results['points']
-        # Points in Camera coord can provide the depth information.
-        # TODO: Need to support distance-based sampling for other coord system.
-        if self.sample_range is not None:
-            from mmdet3d.core.points import CameraPoints
-            assert isinstance(points, CameraPoints), \
-                'Sampling based on distance is only applicable for CAM coord'
         points, choices = self._points_random_sampling(
             points,
             self.num_points,
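Renaming `depth` to `dist` and dropping the `CameraPoints` assertion generalizes range-based sampling in `PointSample`: the cut-off is now a plain Euclidean norm of the point coordinates, so it no longer has to be interpreted as a camera-frame depth. A hedged sketch of a pipeline entry that exercises it; the numbers are illustrative only:

    train_pipeline = [
        # ... loading transforms ...
        dict(
            type='PointSample',
            num_points=50000,
            # Points farther than 40 units are sub-sampled in favour of near
            # points; after this commit the option is no longer restricted to
            # camera-coordinate point clouds.
            sample_range=40.0),
        # ... formatting transforms ...
    ]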
mmdet3d/datasets/s3dis_dataset.py

@@ -54,7 +54,8 @@ class S3DISDataset(Custom3DDataset):
                  modality=None,
                  box_type_3d='Depth',
                  filter_empty_gt=True,
-                 test_mode=False):
+                 test_mode=False,
+                 *kwargs):
         super().__init__(
             data_root=data_root,
             ann_file=ann_file,
@@ -63,7 +64,8 @@ class S3DISDataset(Custom3DDataset):
             modality=modality,
             box_type_3d=box_type_3d,
             filter_empty_gt=filter_empty_gt,
-            test_mode=test_mode)
+            test_mode=test_mode,
+            *kwargs)

     def get_ann_info(self, index):
         """Get annotation info according to the given index.
@@ -85,10 +87,10 @@ class S3DISDataset(Custom3DDataset):
         if info['annos']['gt_num'] != 0:
             gt_bboxes_3d = info['annos']['gt_boxes_upright_depth'].astype(
                 np.float32)  # k, 6
-            gt_labels_3d = info['annos']['class'].astype(np.long)
+            gt_labels_3d = info['annos']['class'].astype(np.int64)
         else:
             gt_bboxes_3d = np.zeros((0, 6), dtype=np.float32)
-            gt_labels_3d = np.zeros((0, ), dtype=np.long)
+            gt_labels_3d = np.zeros((0, ), dtype=np.int64)

         # to target box structure
         gt_bboxes_3d = DepthInstance3DBoxes(
@@ -205,7 +207,8 @@ class _S3DISSegDataset(Custom3DSegDataset):
                  modality=None,
                  test_mode=False,
                  ignore_index=None,
-                 scene_idxs=None):
+                 scene_idxs=None,
+                 **kwargs):
         super().__init__(
             data_root=data_root,
@@ -216,7 +219,8 @@ class _S3DISSegDataset(Custom3DSegDataset):
             modality=modality,
             test_mode=test_mode,
             ignore_index=ignore_index,
-            scene_idxs=scene_idxs)
+            scene_idxs=scene_idxs,
+            **kwargs)

     def get_ann_info(self, index):
         """Get annotation info according to the given index.
@@ -347,7 +351,8 @@ class S3DISSegDataset(_S3DISSegDataset):
                  modality=None,
                  test_mode=False,
                  ignore_index=None,
-                 scene_idxs=None):
+                 scene_idxs=None,
+                 **kwargs):

         # make sure that ann_files and scene_idxs have same length
         ann_files = self._check_ann_files(ann_files)
@@ -363,7 +368,8 @@ class S3DISSegDataset(_S3DISSegDataset):
             modality=modality,
             test_mode=test_mode,
             ignore_index=ignore_index,
-            scene_idxs=scene_idxs[0])
+            scene_idxs=scene_idxs[0],
+            **kwargs)
         datasets = [
             _S3DISSegDataset(
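Note that `S3DISDataset` adds `*kwargs` (a single star, which collects extra positional arguments) where every other dataset in this commit adds `**kwargs`; as written it cannot accept or forward extra keyword arguments, which looks like an oversight rather than an intentional difference. A minimal illustration of the behavioural gap, in generic Python that is not taken from the repository:

    class Base:
        def __init__(self, test_mode=False, file_client_args=None):
            self.file_client_args = file_client_args

    class WithDoubleStar(Base):
        def __init__(self, test_mode=False, **kwargs):
            super().__init__(test_mode=test_mode, **kwargs)  # keywords forwarded

    class WithSingleStar(Base):
        def __init__(self, test_mode=False, *kwargs):
            super().__init__(test_mode=test_mode, *kwargs)   # only positionals collected

    WithDoubleStar(file_client_args=dict(backend='disk'))    # works
    WithSingleStar(file_client_args=dict(backend='disk'))    # TypeError: unexpected keyword argument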
mmdet3d/datasets/scannet_dataset.py

@@ -5,7 +5,7 @@ from os import path as osp
 import numpy as np

-from mmdet3d.core import show_result, show_seg_result
+from mmdet3d.core import instance_seg_eval, show_result, show_seg_result
 from mmdet3d.core.bbox import DepthInstance3DBoxes
 from mmdet.datasets import DATASETS
 from mmseg.datasets import DATASETS as SEG_DATASETS
@@ -58,7 +58,8 @@ class ScanNetDataset(Custom3DDataset):
                  modality=dict(use_camera=False, use_depth=True),
                  box_type_3d='Depth',
                  filter_empty_gt=True,
-                 test_mode=False):
+                 test_mode=False,
+                 **kwargs):
         super().__init__(
             data_root=data_root,
             ann_file=ann_file,
@@ -67,7 +68,8 @@ class ScanNetDataset(Custom3DDataset):
             modality=modality,
             box_type_3d=box_type_3d,
             filter_empty_gt=filter_empty_gt,
-            test_mode=test_mode)
+            test_mode=test_mode,
+            **kwargs)
         assert 'use_camera' in self.modality and \
             'use_depth' in self.modality
         assert self.modality['use_camera'] or self.modality['use_depth']
@@ -143,10 +145,10 @@ class ScanNetDataset(Custom3DDataset):
         if info['annos']['gt_num'] != 0:
             gt_bboxes_3d = info['annos']['gt_boxes_upright_depth'].astype(
                 np.float32)  # k, 6
-            gt_labels_3d = info['annos']['class'].astype(np.long)
+            gt_labels_3d = info['annos']['class'].astype(np.int64)
         else:
             gt_bboxes_3d = np.zeros((0, 6), dtype=np.float32)
-            gt_labels_3d = np.zeros((0, ), dtype=np.long)
+            gt_labels_3d = np.zeros((0, ), dtype=np.int64)

         # to target box structure
         gt_bboxes_3d = DepthInstance3DBoxes(
@@ -322,7 +324,8 @@ class ScanNetSegDataset(Custom3DSegDataset):
                  modality=None,
                  test_mode=False,
                  ignore_index=None,
-                 scene_idxs=None):
+                 scene_idxs=None,
+                 **kwargs):
         super().__init__(
             data_root=data_root,
@@ -333,7 +336,8 @@ class ScanNetSegDataset(Custom3DSegDataset):
             modality=modality,
             test_mode=test_mode,
             ignore_index=ignore_index,
-            scene_idxs=scene_idxs)
+            scene_idxs=scene_idxs,
+            **kwargs)

     def get_ann_info(self, index):
         """Get annotation info according to the given index.
@@ -460,3 +464,151 @@ class ScanNetSegDataset(Custom3DSegDataset):
             outputs.append(dict(seg_mask=pred_label))
         return outputs, tmp_dir
+
+
+@DATASETS.register_module()
+@SEG_DATASETS.register_module()
+class ScanNetInstanceSegDataset(Custom3DSegDataset):
+    CLASSES = ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window',
+               'bookshelf', 'picture', 'counter', 'desk', 'curtain',
+               'refrigerator', 'showercurtrain', 'toilet', 'sink', 'bathtub',
+               'garbagebin')
+
+    VALID_CLASS_IDS = (3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33,
+                       34, 36, 39)
+
+    ALL_CLASS_IDS = tuple(range(41))
+
+    def get_ann_info(self, index):
+        """Get annotation info according to the given index.
+
+        Args:
+            index (int): Index of the annotation data to get.
+
+        Returns:
+            dict: annotation information consists of the following keys:
+
+                - pts_semantic_mask_path (str): Path of semantic masks.
+                - pts_instance_mask_path (str): Path of instance masks.
+        """
+        # Use index to get the annos, thus the evalhook could also use this api
+        info = self.data_infos[index]
+
+        pts_instance_mask_path = osp.join(self.data_root,
+                                          info['pts_instance_mask_path'])
+        pts_semantic_mask_path = osp.join(self.data_root,
+                                          info['pts_semantic_mask_path'])
+
+        anns_results = dict(
+            pts_instance_mask_path=pts_instance_mask_path,
+            pts_semantic_mask_path=pts_semantic_mask_path)
+        return anns_results
+
+    def get_classes_and_palette(self, classes=None, palette=None):
+        """Get class names of current dataset. Palette is simply ignored for
+        instance segmentation.
+
+        Args:
+            classes (Sequence[str] | str | None): If classes is None, use
+                default CLASSES defined by builtin dataset. If classes is a
+                string, take it as a file name. The file contains the name of
+                classes where each line contains one class name. If classes is
+                a tuple or list, override the CLASSES defined by the dataset.
+                Defaults to None.
+            palette (Sequence[Sequence[int]]] | np.ndarray | None):
+                The palette of segmentation map. If None is given, random
+                palette will be generated. Defaults to None.
+        """
+        if classes is not None:
+            return classes, None
+        return self.CLASSES, None
+
+    def _build_default_pipeline(self):
+        """Build the default pipeline for this dataset."""
+        pipeline = [
+            dict(
+                type='LoadPointsFromFile',
+                coord_type='DEPTH',
+                shift_height=False,
+                use_color=True,
+                load_dim=6,
+                use_dim=[0, 1, 2, 3, 4, 5]),
+            dict(
+                type='LoadAnnotations3D',
+                with_bbox_3d=False,
+                with_label_3d=False,
+                with_mask_3d=True,
+                with_seg_3d=True),
+            dict(
+                type='PointSegClassMapping',
+                valid_cat_ids=self.VALID_CLASS_IDS,
+                max_cat_id=40),
+            dict(
+                type='DefaultFormatBundle3D',
+                with_label=False,
+                class_names=self.CLASSES),
+            dict(
+                type='Collect3D',
+                keys=['points', 'pts_semantic_mask', 'pts_instance_mask'])
+        ]
+        return Compose(pipeline)
+
+    def evaluate(self,
+                 results,
+                 metric=None,
+                 options=None,
+                 logger=None,
+                 show=False,
+                 out_dir=None,
+                 pipeline=None):
+        """Evaluation in instance segmentation protocol.
+
+        Args:
+            results (list[dict]): List of results.
+            metric (str | list[str]): Metrics to be evaluated.
+            options (dict, optional): options for instance_seg_eval.
+            logger (logging.Logger | None | str): Logger used for printing
+                related information during evaluation. Defaults to None.
+            show (bool, optional): Whether to visualize.
+                Defaults to False.
+            out_dir (str, optional): Path to save the visualization results.
+                Defaults to None.
+            pipeline (list[dict], optional): raw data loading for showing.
+                Default: None.
+
+        Returns:
+            dict: Evaluation results.
+        """
+        assert isinstance(results, list), \
+            f'Expect results to be list, got {type(results)}.'
+        assert len(results) > 0, 'Expect length of results > 0.'
+        assert len(results) == len(self.data_infos)
+        assert isinstance(results[0], dict), \
+            f'Expect elements in results to be dict, got {type(results[0])}.'
+
+        load_pipeline = self._get_pipeline(pipeline)
+        pred_instance_masks = [result['instance_mask'] for result in results]
+        pred_instance_labels = [result['instance_label'] for result in results]
+        pred_instance_scores = [result['instance_score'] for result in results]
+        gt_semantic_masks, gt_instance_masks = zip(*[
+            self._extract_data(
+                index=i,
+                pipeline=load_pipeline,
+                key=['pts_semantic_mask', 'pts_instance_mask'],
+                load_annos=True) for i in range(len(self.data_infos))
+        ])
+        ret_dict = instance_seg_eval(
+            gt_semantic_masks,
+            gt_instance_masks,
+            pred_instance_masks,
+            pred_instance_labels,
+            pred_instance_scores,
+            valid_class_ids=self.VALID_CLASS_IDS,
+            class_labels=self.CLASSES,
+            options=options,
+            logger=logger)
+
+        if show:
+            raise NotImplementedError('show is not implemented for now')
+
+        return ret_dict
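The newly added `ScanNetInstanceSegDataset` is registered in both the detection and the segmentation registries and evaluates predictions with `instance_seg_eval`, which expects each result dict to carry `instance_mask`, `instance_label` and `instance_score`. A hedged sketch of a dataset config entry that could drive it; the paths are illustrative and are not taken from this commit:

    scannet_instance_seg = dict(
        type='ScanNetInstanceSegDataset',
        data_root='data/scannet/',
        ann_file='data/scannet/scannet_infos_val.pkl',
        # With pipeline=None, evaluation falls back to the class's
        # _build_default_pipeline() when it needs to reload raw masks.
        pipeline=None,
        test_mode=True)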
mmdet3d/datasets/sunrgbd_dataset.py

@@ -54,7 +54,8 @@ class SUNRGBDDataset(Custom3DDataset):
                  modality=dict(use_camera=True, use_lidar=True),
                  box_type_3d='Depth',
                  filter_empty_gt=True,
-                 test_mode=False):
+                 test_mode=False,
+                 **kwargs):
         super().__init__(
             data_root=data_root,
             ann_file=ann_file,
@@ -63,7 +64,8 @@ class SUNRGBDDataset(Custom3DDataset):
             modality=modality,
             box_type_3d=box_type_3d,
             filter_empty_gt=filter_empty_gt,
-            test_mode=test_mode)
+            test_mode=test_mode,
+            **kwargs)
         assert 'use_camera' in self.modality and \
             'use_lidar' in self.modality
         assert self.modality['use_camera'] or self.modality['use_lidar']
@@ -137,10 +139,10 @@ class SUNRGBDDataset(Custom3DDataset):
         if info['annos']['gt_num'] != 0:
             gt_bboxes_3d = info['annos']['gt_boxes_upright_depth'].astype(
                 np.float32)  # k, 6
-            gt_labels_3d = info['annos']['class'].astype(np.long)
+            gt_labels_3d = info['annos']['class'].astype(np.int64)
         else:
             gt_bboxes_3d = np.zeros((0, 7), dtype=np.float32)
-            gt_labels_3d = np.zeros((0, ), dtype=np.long)
+            gt_labels_3d = np.zeros((0, ), dtype=np.int64)

         # to target box structure
         gt_bboxes_3d = DepthInstance3DBoxes(
mmdet3d/datasets/waymo_dataset.py

@@ -66,7 +66,8 @@ class WaymoDataset(KittiDataset):
                  filter_empty_gt=True,
                  test_mode=False,
                  load_interval=1,
-                 pcd_limit_range=[-85, -85, -5, 85, 85, 5]):
+                 pcd_limit_range=[-85, -85, -5, 85, 85, 5],
+                 **kwargs):
         super().__init__(
             data_root=data_root,
             ann_file=ann_file,
@@ -78,7 +79,8 @@ class WaymoDataset(KittiDataset):
             box_type_3d=box_type_3d,
             filter_empty_gt=filter_empty_gt,
             test_mode=test_mode,
-            pcd_limit_range=pcd_limit_range)
+            pcd_limit_range=pcd_limit_range,
+            **kwargs)

         # to load a subset, just set the load_interval in the dataset config
         self.data_infos = self.data_infos[::load_interval]
mmdet3d/models/dense_heads/centerpoint_head.py

@@ -3,6 +3,7 @@ import copy
 import torch
 from mmcv.cnn import ConvModule, build_conv_layer
+from mmcv.ops import nms_bev as nms_gpu
 from mmcv.runner import BaseModule, force_fp32
 from torch import nn
@@ -11,7 +12,6 @@ from mmdet3d.core import (circle_nms, draw_heatmap_gaussian, gaussian_radius,
 from mmdet3d.models import builder
 from mmdet3d.models.builder import HEADS, build_loss
 from mmdet3d.models.utils import clip_sigmoid
-from mmdet3d.ops.iou3d.iou3d_utils import nms_gpu
 from mmdet.core import build_bbox_coder, multi_apply
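This import swap, repeated in parta2_rpn_head.py and point_rpn_head.py below, is part of moving the custom CUDA ops out of `mmdet3d.ops` and into MMCV: the BEV NMS now comes from `mmcv.ops.nms_bev`, aliased back to the old `nms_gpu` name so the call sites stay untouched. A small sketch of the compatibility pattern, assuming a CUDA build of mmcv-full; the boxes and threshold are placeholders:

    import torch
    from mmcv.ops import nms_bev as nms_gpu  # same name the head code already uses

    # Rotated BEV boxes of shape (N, 5), in the layout the heads produce via
    # xywhr2xyxyr before calling NMS.
    boxes = torch.rand(8, 5).cuda()
    scores = torch.rand(8).cuda()
    keep = nms_gpu(boxes, scores, thresh=0.1)  # indices of the kept boxes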
mmdet3d/models/dense_heads/groupfree3d_head.py

@@ -7,13 +7,14 @@ from mmcv import ConfigDict
 from mmcv.cnn import ConvModule, xavier_init
 from mmcv.cnn.bricks.transformer import (build_positional_encoding,
                                          build_transformer_layer)
+from mmcv.ops import PointsSampler as Points_Sampler
+from mmcv.ops import gather_points
 from mmcv.runner import BaseModule, force_fp32
 from torch import nn as nn
 from torch.nn import functional as F

 from mmdet3d.core.post_processing import aligned_3d_nms
 from mmdet3d.models.builder import build_loss
-from mmdet3d.ops import Points_Sampler, gather_points
 from mmdet.core import build_bbox_coder, multi_apply
 from mmdet.models import HEADS
 from .base_conv_bbox_head import BaseConvBboxHead
mmdet3d/models/dense_heads/parta2_rpn_head.py

@@ -3,10 +3,11 @@ from __future__ import division
 import numpy as np
 import torch
+from mmcv.ops import nms_bev as nms_gpu
+from mmcv.ops import nms_normal_bev as nms_normal_gpu
 from mmcv.runner import force_fp32

 from mmdet3d.core import limit_period, xywhr2xyxyr
-from mmdet3d.ops.iou3d.iou3d_utils import nms_gpu, nms_normal_gpu
 from mmdet.models import HEADS
 from .anchor3d_head import Anchor3DHead
mmdet3d/models/dense_heads/point_rpn_head.py

 # Copyright (c) OpenMMLab. All rights reserved.
 import torch
+from mmcv.ops import nms_bev as nms_gpu
+from mmcv.ops import nms_normal_bev as nms_normal_gpu
 from mmcv.runner import BaseModule, force_fp32
 from torch import nn as nn

 from mmdet3d.core.bbox.structures import (DepthInstance3DBoxes,
                                           LiDARInstance3DBoxes)
-from mmdet3d.ops.iou3d.iou3d_utils import nms_gpu, nms_normal_gpu
 from mmdet.core import build_bbox_coder, multi_apply
 from mmdet.models import HEADS, build_loss
mmdet3d/models/dense_heads/vote_head.py

 # Copyright (c) OpenMMLab. All rights reserved.
 import numpy as np
 import torch
+from mmcv.ops import furthest_point_sample
 from mmcv.runner import BaseModule, force_fp32
 from torch.nn import functional as F
@@ -8,7 +9,7 @@ from mmdet3d.core.post_processing import aligned_3d_nms
 from mmdet3d.models.builder import build_loss
 from mmdet3d.models.losses import chamfer_distance
 from mmdet3d.models.model_utils import VoteModule
-from mmdet3d.ops import build_sa_module, furthest_point_sample
+from mmdet3d.ops import build_sa_module
 from mmdet.core import build_bbox_coder, multi_apply
 from mmdet.models import HEADS
 from .base_conv_bbox_head import BaseConvBboxHead
mmdet3d/models/detectors/mvx_two_stage.py

@@ -4,13 +4,13 @@ from os import path as osp
 import mmcv
 import torch
+from mmcv.ops import Voxelization
 from mmcv.parallel import DataContainer as DC
 from mmcv.runner import force_fp32
 from torch.nn import functional as F

 from mmdet3d.core import (Box3DMode, Coord3DMode, bbox3d2result,
                           merge_aug_bboxes_3d, show_result)
-from mmdet3d.ops import Voxelization
 from mmdet.core import multi_apply
 from mmdet.models import DETECTORS
 from .. import builder
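The same op migration applies to the voxelization layer: `Voxelization` now comes from `mmcv.ops` instead of `mmdet3d.ops`, with no change at the call sites (the two detectors below get the identical one-line fix). A short sketch of constructing the layer under the new import; the parameter values are illustrative, not taken from a specific config in this commit:

    import torch
    from mmcv.ops import Voxelization

    # Same interface as the old mmdet3d.ops wrapper.
    voxel_layer = Voxelization(
        voxel_size=[0.05, 0.05, 0.1],
        point_cloud_range=[0, -40, -3, 70.4, 40, 1],
        max_num_points=5,
        max_voxels=16000)
    points = torch.rand(1000, 4)  # (x, y, z, intensity); on GPU in a real model
    voxels, coors, num_points_per_voxel = voxel_layer(points)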
mmdet3d/models/detectors/parta2.py

 # Copyright (c) OpenMMLab. All rights reserved.
 import torch
+from mmcv.ops import Voxelization
 from torch.nn import functional as F

-from mmdet3d.ops import Voxelization
 from mmdet.models import DETECTORS
 from .. import builder
 from .two_stage import TwoStage3DDetector
mmdet3d/models/detectors/voxelnet.py

 # Copyright (c) OpenMMLab. All rights reserved.
 import torch
+from mmcv.ops import Voxelization
 from mmcv.runner import force_fp32
 from torch.nn import functional as F

 from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d
-from mmdet3d.ops import Voxelization
 from mmdet.models import DETECTORS
 from .. import builder
 from .single_stage import SingleStage3DDetector
mmdet3d/models/middle_encoders/pillar_scatter.py

@@ -50,14 +50,14 @@ class PointPillarsScatter(nn.Module):
             dtype=voxel_features.dtype,
             device=voxel_features.device)
-        indices = coors[:, 1] * self.nx + coors[:, 2]
+        indices = coors[:, 2] * self.nx + coors[:, 3]
         indices = indices.long()
         voxels = voxel_features.t()
         # Now scatter the blob back to the canvas.
         canvas[:, indices] = voxels
         # Undo the column stacking to final 4-dim tensor
         canvas = canvas.view(1, self.in_channels, self.ny, self.nx)
-        return [canvas]
+        return canvas

     def forward_batch(self, voxel_features, coors, batch_size):
         """Scatter features of single sample.
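The single-sample path now reads the y/x pillar indices from columns 2 and 3 of `coors` and returns the canvas directly instead of wrapped in a list, matching the batched path: after batching, `coors` carries four columns `(batch_idx, z, y, x)`, so column 1 is the z index rather than y. A minimal sketch of the indexing assumption; the shapes are illustrative:

    import torch

    # coors after batching: one row per pillar, columns (batch_idx, z, y, x).
    coors = torch.tensor([[0, 0, 3, 7],
                          [0, 0, 5, 2]])
    nx = 10  # canvas width in pillars
    # Flattened canvas index, as computed in forward_single after this commit:
    indices = coors[:, 2] * nx + coors[:, 3]  # y * nx + x -> tensor([37, 52])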
mmdet3d/models/middle_encoders/sparse_encoder.py

 # Copyright (c) OpenMMLab. All rights reserved.
+from mmcv.ops import SparseConvTensor, SparseSequential
 from mmcv.runner import auto_fp16
 from torch import nn as nn

 from mmdet3d.ops import SparseBasicBlock, make_sparse_convmodule
-from mmdet3d.ops import spconv as spconv
 from ..builder import MIDDLE_ENCODERS
@@ -109,9 +109,8 @@ class SparseEncoder(nn.Module):
             dict: Backbone features.
         """
         coors = coors.int()
-        input_sp_tensor = spconv.SparseConvTensor(voxel_features, coors,
-                                                  self.sparse_shape,
-                                                  batch_size)
+        input_sp_tensor = SparseConvTensor(voxel_features, coors,
+                                           self.sparse_shape,
+                                           batch_size)
         x = self.conv_input(input_sp_tensor)
         encode_features = []
@@ -150,7 +149,7 @@ class SparseEncoder(nn.Module):
             int: The number of encoder output channels.
         """
         assert block_type in ['conv_module', 'basicblock']
-        self.encoder_layers = spconv.SparseSequential()
+        self.encoder_layers = SparseSequential()
         for i, blocks in enumerate(self.encoder_channels):
             blocks_list = []
@@ -201,6 +200,6 @@ class SparseEncoder(nn.Module):
                         conv_type='SubMConv3d'))
                 in_channels = out_channels
             stage_name = f'encoder_layer{i + 1}'
-            stage_layers = spconv.SparseSequential(*blocks_list)
+            stage_layers = SparseSequential(*blocks_list)
             self.encoder_layers.add_module(stage_name, stage_layers)
         return out_channels
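Here the sparse-convolution containers also move to MMCV: `SparseConvTensor` and `SparseSequential` are imported from `mmcv.ops`, and the bundled `mmdet3d.ops.spconv` copy is no longer referenced (sparse_unet.py below gets the same treatment). A brief sketch of building a sparse tensor with the new import; the shapes are made up for illustration and real configs use much larger grids such as [41, 1600, 1408]:

    import torch
    from mmcv.ops import SparseConvTensor

    # features: (num_voxels, C); indices: (num_voxels, 4) as (batch_idx, z, y, x), int32.
    features = torch.rand(3, 16)
    indices = torch.tensor([[0, 0, 0, 0],
                            [0, 1, 2, 3],
                            [1, 0, 5, 7]], dtype=torch.int32)
    spatial_shape = [4, 8, 8]
    batch_size = 2
    x = SparseConvTensor(features, indices, spatial_shape, batch_size)
    dense = x.dense()  # (batch_size, 16, 4, 8, 8), as in the original spconv API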
mmdet3d/models/middle_encoders/sparse_unet.py

 # Copyright (c) OpenMMLab. All rights reserved.
 import torch
+from mmcv.ops import SparseConvTensor, SparseSequential
 from mmcv.runner import BaseModule, auto_fp16

 from mmdet3d.ops import SparseBasicBlock, make_sparse_convmodule
-from mmdet3d.ops import spconv as spconv
 from ..builder import MIDDLE_ENCODERS
@@ -108,9 +108,8 @@ class SparseUNet(BaseModule):
             dict[str, torch.Tensor]: Backbone features.
         """
         coors = coors.int()
-        input_sp_tensor = spconv.SparseConvTensor(voxel_features, coors,
-                                                  self.sparse_shape,
-                                                  batch_size)
+        input_sp_tensor = SparseConvTensor(voxel_features, coors,
+                                           self.sparse_shape,
+                                           batch_size)
         x = self.conv_input(input_sp_tensor)
         encode_features = []
@@ -200,7 +199,7 @@ class SparseUNet(BaseModule):
         Returns:
             int: The number of encoder output channels.
         """
-        self.encoder_layers = spconv.SparseSequential()
+        self.encoder_layers = SparseSequential()
         for i, blocks in enumerate(self.encoder_channels):
             blocks_list = []
@@ -231,7 +230,7 @@ class SparseUNet(BaseModule):
                         conv_type='SubMConv3d'))
                 in_channels = out_channels
             stage_name = f'encoder_layer{i + 1}'
-            stage_layers = spconv.SparseSequential(*blocks_list)
+            stage_layers = SparseSequential(*blocks_list)
             self.encoder_layers.add_module(stage_name, stage_layers)
         return out_channels