OpenDAS / mmdetection3d · Commits · 76e351a7
"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "7eb16d1f35b224a92abbb4a6bb7bbb612a48a9ed"
Unverified commit 76e351a7, authored May 01, 2022 by Wenwei Zhang, committed by GitHub on May 01, 2022

Release v1.0.0rc2

Parents: 5111eda8, 4422eaab
Changes: 137
Showing 20 changed files (this page) with 553 additions and 47 deletions (+553, -47).
mmdet3d/__init__.py                                 +1    -1
mmdet3d/apis/train.py                               +283  -3
mmdet3d/core/bbox/structures/base_box3d.py          +9    -12
mmdet3d/core/evaluation/scannet_utils/__init__.py   +4    -0
mmdet3d/core/evaluation/waymo_utils/__init__.py     +4    -0
mmdet3d/core/post_processing/__init__.py            +4    -2
mmdet3d/core/post_processing/box3d_nms.py           +65   -4
mmdet3d/core/post_processing/merge_augs.py          +3    -4
mmdet3d/datasets/__init__.py                        +2    -2
mmdet3d/datasets/builder.py                         +2    -1
mmdet3d/datasets/custom_3d.py                       +60   -3
mmdet3d/datasets/custom_3d_seg.py                   +1    -1
mmdet3d/datasets/kitti2d_dataset.py                 +2    -1
mmdet3d/datasets/kitti_dataset.py                   +1    -1
mmdet3d/datasets/kitti_mono_dataset.py              +5    -1
mmdet3d/datasets/lyft_dataset.py                    +1    -1
mmdet3d/datasets/nuscenes_dataset.py                +3    -6
mmdet3d/datasets/nuscenes_mono_dataset.py           +43   -3
mmdet3d/datasets/pipelines/__init__.py              +1    -1
mmdet3d/datasets/pipelines/compose.py               +59   -0
mmdet3d/__init__.py

@@ -19,7 +19,7 @@ def digit_version(version_str):
 mmcv_minimum_version = '1.4.8'
-mmcv_maximum_version = '1.5.0'
+mmcv_maximum_version = '1.6.0'
 mmcv_version = digit_version(mmcv.__version__)
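This hunk only widens the accepted mmcv range; the enforcement itself sits a few lines below the hunk and is not shown. A hedged sketch of that guard, assuming the usual OpenMMLab pattern:

# Sketch of the version guard that consumes these bounds (assumed pattern,
# not part of the hunk above): reject an installed mmcv outside the range.
assert (mmcv_version >= digit_version(mmcv_minimum_version)
        and mmcv_version <= digit_version(mmcv_maximum_version)), \
    f'MMCV=={mmcv.__version__} is used but incompatible. ' \
    f'Please install mmcv>={mmcv_minimum_version}, <={mmcv_maximum_version}.'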
mmdet3d/apis/train.py

 # Copyright (c) OpenMMLab. All rights reserved.
+import random
+import warnings
+
 import numpy as np
 import torch
-from mmcv.runner import get_dist_info
+from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
+from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner,
+                         Fp16OptimizerHook, OptimizerHook, build_optimizer,
+                         build_runner, get_dist_info)
+from mmcv.utils import build_from_cfg
 from torch import distributed as dist

-from mmdet.apis import train_detector
 from mmdet3d.datasets import build_dataset
-from mmseg.apis import train_segmentor
+from mmdet3d.utils import find_latest_checkpoint
+from mmdet.core import DistEvalHook as MMDET_DistEvalHook
+from mmdet.core import EvalHook as MMDET_EvalHook
+from mmdet.datasets import build_dataloader as build_mmdet_dataloader
+from mmdet.datasets import replace_ImageToTensor
+from mmdet.utils import get_root_logger as get_mmdet_root_logger
+from mmseg.core import DistEvalHook as MMSEG_DistEvalHook
+from mmseg.core import EvalHook as MMSEG_EvalHook
+from mmseg.datasets import build_dataloader as build_mmseg_dataloader
+from mmseg.utils import get_root_logger as get_mmseg_root_logger


 def init_random_seed(seed=None, device='cuda'):

@@ -39,6 +55,270 @@ def init_random_seed(seed=None, device='cuda'):
     return random_num.item()


+def set_random_seed(seed, deterministic=False):
+    """Set random seed.
+
+    Args:
+        seed (int): Seed to be used.
+        deterministic (bool): Whether to set the deterministic option for
+            CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
+            to True and `torch.backends.cudnn.benchmark` to False.
+            Default: False.
+    """
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    if deterministic:
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
+
+
+def train_segmentor(model,
+                    dataset,
+                    cfg,
+                    distributed=False,
+                    validate=False,
+                    timestamp=None,
+                    meta=None):
+    """Launch segmentor training."""
+    logger = get_mmseg_root_logger(cfg.log_level)
+
+    # prepare data loaders
+    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
+    data_loaders = [
+        build_mmseg_dataloader(
+            ds,
+            cfg.data.samples_per_gpu,
+            cfg.data.workers_per_gpu,
+            # cfg.gpus will be ignored if distributed
+            len(cfg.gpu_ids),
+            dist=distributed,
+            seed=cfg.seed,
+            drop_last=True) for ds in dataset
+    ]
+
+    # put model on gpus
+    if distributed:
+        find_unused_parameters = cfg.get('find_unused_parameters', False)
+        # Sets the `find_unused_parameters` parameter in
+        # torch.nn.parallel.DistributedDataParallel
+        model = MMDistributedDataParallel(
+            model.cuda(),
+            device_ids=[torch.cuda.current_device()],
+            broadcast_buffers=False,
+            find_unused_parameters=find_unused_parameters)
+    else:
+        model = MMDataParallel(
+            model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
+
+    # build runner
+    optimizer = build_optimizer(model, cfg.optimizer)
+
+    if cfg.get('runner') is None:
+        cfg.runner = {'type': 'IterBasedRunner', 'max_iters': cfg.total_iters}
+        warnings.warn(
+            'config is now expected to have a `runner` section, '
+            'please set `runner` in your config.', UserWarning)
+
+    runner = build_runner(
+        cfg.runner,
+        default_args=dict(
+            model=model,
+            batch_processor=None,
+            optimizer=optimizer,
+            work_dir=cfg.work_dir,
+            logger=logger,
+            meta=meta))
+
+    # register hooks
+    runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config,
+                                   cfg.checkpoint_config, cfg.log_config,
+                                   cfg.get('momentum_config', None))
+
+    # an ugly walkaround to make the .log and .log.json filenames the same
+    runner.timestamp = timestamp
+
+    # register eval hooks
+    if validate:
+        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
+        val_dataloader = build_mmseg_dataloader(
+            val_dataset,
+            samples_per_gpu=1,
+            workers_per_gpu=cfg.data.workers_per_gpu,
+            dist=distributed,
+            shuffle=False)
+        eval_cfg = cfg.get('evaluation', {})
+        eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
+        eval_hook = MMSEG_DistEvalHook if distributed else MMSEG_EvalHook
+        # In this PR (https://github.com/open-mmlab/mmcv/pull/1193), the
+        # priority of IterTimerHook has been modified from 'NORMAL' to 'LOW'.
+        runner.register_hook(
+            eval_hook(val_dataloader, **eval_cfg), priority='LOW')
+
+    # user-defined hooks
+    if cfg.get('custom_hooks', None):
+        custom_hooks = cfg.custom_hooks
+        assert isinstance(custom_hooks, list), \
+            f'custom_hooks expect list type, but got {type(custom_hooks)}'
+        for hook_cfg in cfg.custom_hooks:
+            assert isinstance(hook_cfg, dict), \
+                'Each item in custom_hooks expects dict type, but got ' \
+                f'{type(hook_cfg)}'
+            hook_cfg = hook_cfg.copy()
+            priority = hook_cfg.pop('priority', 'NORMAL')
+            hook = build_from_cfg(hook_cfg, HOOKS)
+            runner.register_hook(hook, priority=priority)
+
+    if cfg.resume_from:
+        runner.resume(cfg.resume_from)
+    elif cfg.load_from:
+        runner.load_checkpoint(cfg.load_from)
+    runner.run(data_loaders, cfg.workflow)
+
+
+def train_detector(model,
+                   dataset,
+                   cfg,
+                   distributed=False,
+                   validate=False,
+                   timestamp=None,
+                   meta=None):
+    logger = get_mmdet_root_logger(log_level=cfg.log_level)
+
+    # prepare data loaders
+    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
+
+    if 'imgs_per_gpu' in cfg.data:
+        logger.warning('"imgs_per_gpu" is deprecated in MMDet V2.0. '
+                       'Please use "samples_per_gpu" instead')
+        if 'samples_per_gpu' in cfg.data:
+            logger.warning(
+                f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and '
+                f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"'
+                f'={cfg.data.imgs_per_gpu} is used in this experiments')
+        else:
+            logger.warning(
+                'Automatically set "samples_per_gpu"="imgs_per_gpu"='
+                f'{cfg.data.imgs_per_gpu} in this experiments')
+        cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu
+
+    runner_type = 'EpochBasedRunner' if 'runner' not in cfg else cfg.runner[
+        'type']
+    data_loaders = [
+        build_mmdet_dataloader(
+            ds,
+            cfg.data.samples_per_gpu,
+            cfg.data.workers_per_gpu,
+            # `num_gpus` will be ignored if distributed
+            num_gpus=len(cfg.gpu_ids),
+            dist=distributed,
+            seed=cfg.seed,
+            runner_type=runner_type,
+            persistent_workers=cfg.data.get('persistent_workers', False))
+        for ds in dataset
+    ]
+
+    # put model on gpus
+    if distributed:
+        find_unused_parameters = cfg.get('find_unused_parameters', False)
+        # Sets the `find_unused_parameters` parameter in
+        # torch.nn.parallel.DistributedDataParallel
+        model = MMDistributedDataParallel(
+            model.cuda(),
+            device_ids=[torch.cuda.current_device()],
+            broadcast_buffers=False,
+            find_unused_parameters=find_unused_parameters)
+    else:
+        model = MMDataParallel(
+            model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
+
+    # build runner
+    optimizer = build_optimizer(model, cfg.optimizer)
+
+    if 'runner' not in cfg:
+        cfg.runner = {
+            'type': 'EpochBasedRunner',
+            'max_epochs': cfg.total_epochs
+        }
+        warnings.warn(
+            'config is now expected to have a `runner` section, '
+            'please set `runner` in your config.', UserWarning)
+    else:
+        if 'total_epochs' in cfg:
+            assert cfg.total_epochs == cfg.runner.max_epochs
+
+    runner = build_runner(
+        cfg.runner,
+        default_args=dict(
+            model=model,
+            optimizer=optimizer,
+            work_dir=cfg.work_dir,
+            logger=logger,
+            meta=meta))
+
+    # an ugly workaround to make .log and .log.json filenames the same
+    runner.timestamp = timestamp
+
+    # fp16 setting
+    fp16_cfg = cfg.get('fp16', None)
+    if fp16_cfg is not None:
+        optimizer_config = Fp16OptimizerHook(
+            **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
+    elif distributed and 'type' not in cfg.optimizer_config:
+        optimizer_config = OptimizerHook(**cfg.optimizer_config)
+    else:
+        optimizer_config = cfg.optimizer_config
+
+    # register hooks
+    runner.register_training_hooks(
+        cfg.lr_config,
+        optimizer_config,
+        cfg.checkpoint_config,
+        cfg.log_config,
+        cfg.get('momentum_config', None),
+        custom_hooks_config=cfg.get('custom_hooks', None))
+
+    if distributed:
+        if isinstance(runner, EpochBasedRunner):
+            runner.register_hook(DistSamplerSeedHook())
+
+    # register eval hooks
+    if validate:
+        # Support batch_size > 1 in validation
+        val_samples_per_gpu = cfg.data.val.pop('samples_per_gpu', 1)
+        if val_samples_per_gpu > 1:
+            # Replace 'ImageToTensor' to 'DefaultFormatBundle'
+            cfg.data.val.pipeline = replace_ImageToTensor(
+                cfg.data.val.pipeline)
+        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
+        val_dataloader = build_mmdet_dataloader(
+            val_dataset,
+            samples_per_gpu=val_samples_per_gpu,
+            workers_per_gpu=cfg.data.workers_per_gpu,
+            dist=distributed,
+            shuffle=False)
+        eval_cfg = cfg.get('evaluation', {})
+        eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
+        eval_hook = MMDET_DistEvalHook if distributed else MMDET_EvalHook
+        # In this PR (https://github.com/open-mmlab/mmcv/pull/1193), the
+        # priority of IterTimerHook has been modified from 'NORMAL' to 'LOW'.
+        runner.register_hook(
+            eval_hook(val_dataloader, **eval_cfg), priority='LOW')
+
+    resume_from = None
+    if cfg.resume_from is None and cfg.get('auto_resume'):
+        resume_from = find_latest_checkpoint(cfg.work_dir)
+    if resume_from is not None:
+        cfg.resume_from = resume_from
+
+    if cfg.resume_from:
+        runner.resume(cfg.resume_from)
+    elif cfg.load_from:
+        runner.load_checkpoint(cfg.load_from)
+    runner.run(data_loaders, cfg.workflow)
+
+
 def train_model(model,
                 dataset,
                 cfg,
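The body of train_model is outside the visible hunk, but since the upstream mmdet.apis.train_detector and mmseg.apis.train_segmentor imports were dropped, it presumably now dispatches to the in-repo trainers defined above. A hedged sketch of such a dispatcher, assuming a cfg.model.type check (not confirmed by this diff):

# Hedged sketch only: route to the local train_segmentor / train_detector.
def train_model(model, dataset, cfg, distributed=False, validate=False,
                timestamp=None, meta=None):
    if cfg.model.type in ['EncoderDecoder3D']:  # assumed segmentor marker
        train_segmentor(model, dataset, cfg, distributed=distributed,
                        validate=validate, timestamp=timestamp, meta=meta)
    else:
        train_detector(model, dataset, cfg, distributed=distributed,
                       validate=validate, timestamp=timestamp, meta=meta)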
mmdet3d/core/bbox/structures/base_box3d.py

@@ -4,10 +4,9 @@ from abc import abstractmethod
 import numpy as np
 import torch
-from mmcv._ext import iou3d_boxes_overlap_bev_forward as boxes_overlap_bev_gpu
-from mmcv.ops import points_in_boxes_all, points_in_boxes_part
+from mmcv.ops import box_iou_rotated, points_in_boxes_all, points_in_boxes_part

-from .utils import limit_period, xywhr2xyxyr
+from .utils import limit_period


 class BaseInstance3DBoxes(object):

@@ -447,7 +446,7 @@ class BaseInstance3DBoxes(object):
             mode (str, optional): Mode of iou calculation. Defaults to 'iou'.

         Returns:
-            torch.Tensor: Calculated iou of boxes' heights.
+            torch.Tensor: Calculated 3D overlaps of the boxes.
         """
         assert isinstance(boxes1, BaseInstance3DBoxes)
         assert isinstance(boxes2, BaseInstance3DBoxes)

@@ -464,15 +463,13 @@ class BaseInstance3DBoxes(object):
         # height overlap
         overlaps_h = cls.height_overlaps(boxes1, boxes2)

-        # obtain BEV boxes in XYXYR format
-        boxes1_bev = xywhr2xyxyr(boxes1.bev)
-        boxes2_bev = xywhr2xyxyr(boxes2.bev)
-
         # bev overlap
-        overlaps_bev = boxes1_bev.new_zeros(
-            (boxes1_bev.shape[0], boxes2_bev.shape[0])).cuda()  # (N, M)
-        boxes_overlap_bev_gpu(boxes1_bev.contiguous().cuda(),
-                              boxes2_bev.contiguous().cuda(), overlaps_bev)
+        iou2d = box_iou_rotated(boxes1.bev, boxes2.bev)
+        areas1 = (boxes1.bev[:, 2] * boxes1.bev[:, 3]).unsqueeze(1).expand(
+            rows, cols)
+        areas2 = (boxes2.bev[:, 2] * boxes2.bev[:, 3]).unsqueeze(0).expand(
+            rows, cols)
+        overlaps_bev = iou2d * (areas1 + areas2) / (1 + iou2d)

         # 3d overlaps
         overlaps_3d = overlaps_bev.to(boxes1.device) * overlaps_h
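The new BEV branch recovers the intersection area from mmcv's rotated IoU instead of calling a dedicated overlap kernel: with iou = inter / (area1 + area2 - inter), solving for inter gives inter = iou * (area1 + area2) / (1 + iou), which is exactly the overlaps_bev expression above. A tiny self-contained sanity check of that identity (illustrative numbers only):

# Two unit squares shifted by 0.5 overlap in area 0.5, so
# iou = 0.5 / (1 + 1 - 0.5) = 1/3; the identity recovers the 0.5.
iou = 1 / 3
area1 = area2 = 1.0
assert abs(iou * (area1 + area2) / (1 + iou) - 0.5) < 1e-9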
mmdet3d/core/evaluation/scannet_utils/__init__.py (new file, mode 100644)

+# Copyright (c) OpenMMLab. All rights reserved.
+from .evaluate_semantic_instance import evaluate_matches, scannet_eval
+
+__all__ = ['scannet_eval', 'evaluate_matches']
mmdet3d/core/evaluation/waymo_utils/__init__.py (new file, mode 100644)

+# Copyright (c) OpenMMLab. All rights reserved.
+from .prediction_kitti_to_waymo import KITTI2Waymo
+
+__all__ = ['KITTI2Waymo']
mmdet3d/core/post_processing/__init__.py

@@ -2,11 +2,13 @@
 from mmdet.core.post_processing import (merge_aug_bboxes, merge_aug_masks,
                                         merge_aug_proposals, merge_aug_scores,
                                         multiclass_nms)
-from .box3d_nms import aligned_3d_nms, box3d_multiclass_nms, circle_nms
+from .box3d_nms import (aligned_3d_nms, box3d_multiclass_nms, circle_nms,
+                        nms_bev, nms_normal_bev)
 from .merge_augs import merge_aug_bboxes_3d

 __all__ = [
     'multiclass_nms', 'merge_aug_proposals', 'merge_aug_bboxes',
     'merge_aug_scores', 'merge_aug_masks', 'box3d_multiclass_nms',
-    'aligned_3d_nms', 'merge_aug_bboxes_3d', 'circle_nms'
+    'aligned_3d_nms', 'merge_aug_bboxes_3d', 'circle_nms', 'nms_bev',
+    'nms_normal_bev'
 ]
mmdet3d/core/post_processing/box3d_nms.py

@@ -2,8 +2,9 @@
 import numba
 import numpy as np
 import torch
-from mmcv.ops import nms_bev as nms_gpu
-from mmcv.ops import nms_normal_bev as nms_normal_gpu
+from mmcv.ops import nms, nms_rotated
+
+from ..bbox import xywhr2xyxyr


 def box3d_multiclass_nms(mlvl_bboxes,

@@ -61,9 +62,9 @@ def box3d_multiclass_nms(mlvl_bboxes,
         _bboxes_for_nms = mlvl_bboxes_for_nms[cls_inds, :]
         if cfg.use_rotate_nms:
-            nms_func = nms_gpu
+            nms_func = nms_bev
         else:
-            nms_func = nms_normal_gpu
+            nms_func = nms_normal_bev

         selected = nms_func(_bboxes_for_nms, _scores, cfg.nms_thr)
         _mlvl_bboxes = mlvl_bboxes[cls_inds, :]

@@ -224,3 +225,63 @@ def circle_nms(dets, thresh, post_max_size=83):
         return keep[:post_max_size]

     return keep
+
+
+# This function duplicates functionality of mmcv.ops.iou_3d.nms_bev
+# from mmcv<=1.5, but using cuda ops from mmcv.ops.nms.nms_rotated.
+# Nms api will be unified in mmdetection3d one day.
+def nms_bev(boxes, scores, thresh, pre_max_size=None, post_max_size=None):
+    """NMS function GPU implementation (for BEV boxes). The overlap of two
+    boxes for IoU calculation is defined as the exact overlapping area of the
+    two boxes. In this function, one can also set ``pre_max_size`` and
+    ``post_max_size``.
+
+    Args:
+        boxes (torch.Tensor): Input boxes with the shape of [N, 5]
+            ([x1, y1, x2, y2, ry]).
+        scores (torch.Tensor): Scores of boxes with the shape of [N].
+        thresh (float): Overlap threshold of NMS.
+        pre_max_size (int, optional): Max size of boxes before NMS.
+            Default: None.
+        post_max_size (int, optional): Max size of boxes after NMS.
+            Default: None.
+
+    Returns:
+        torch.Tensor: Indexes after NMS.
+    """
+    assert boxes.size(1) == 5, 'Input boxes shape should be [N, 5]'
+    order = scores.sort(0, descending=True)[1]
+    if pre_max_size is not None:
+        order = order[:pre_max_size]
+    boxes = boxes[order].contiguous()
+
+    # xyxyr -> back to xywhr
+    # note: better skip this step before nms_bev call in the future
+    boxes = torch.stack(
+        ((boxes[:, 0] + boxes[:, 2]) / 2, (boxes[:, 1] + boxes[:, 3]) / 2,
+         boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1], boxes[:, 4]),
+        dim=-1)
+
+    keep = nms_rotated(boxes, scores, thresh)[1]
+    if post_max_size is not None:
+        keep = keep[:post_max_size]
+    return keep
+
+
+# This function duplicates functionality of mmcv.ops.iou_3d.nms_normal_bev
+# from mmcv<=1.5, but using cuda ops from mmcv.ops.nms.nms.
+# Nms api will be unified in mmdetection3d one day.
+def nms_normal_bev(boxes, scores, thresh):
+    """Normal NMS function GPU implementation (for BEV boxes). The overlap of
+    two boxes for IoU calculation is defined as the exact overlapping area of
+    the two boxes WITH their yaw angle set to 0.
+
+    Args:
+        boxes (torch.Tensor): Input boxes with shape (N, 5).
+        scores (torch.Tensor): Scores of predicted boxes with shape (N).
+        thresh (float): Overlap threshold of NMS.
+
+    Returns:
+        torch.Tensor: Remaining indices with scores in descending order.
+    """
+    assert boxes.shape[1] == 5, 'Input boxes shape should be [N, 5]'
+
+    return nms(xywhr2xyxyr(boxes)[:, :-1], scores, thresh)[1]
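With these wrappers, callers use mmdet3d's own nms_bev/nms_normal_bev rather than the mmcv<=1.5 iou_3d ops. A hedged usage sketch on synthetic data (illustrative values; a CUDA device is assumed, matching the "GPU implementation" note in the docstrings):

import torch
from mmdet3d.core.post_processing import nms_bev

# Two nearly identical BEV boxes in [x1, y1, x2, y2, ry] format; the
# lower-scoring duplicate should be suppressed at thresh=0.5.
boxes = torch.tensor([[0.0, 0.0, 2.0, 2.0, 0.0],
                      [0.1, 0.0, 2.1, 2.0, 0.0]]).cuda()
scores = torch.tensor([0.9, 0.3]).cuda()
keep = nms_bev(boxes, scores, thresh=0.5)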
mmdet3d/core/post_processing/merge_augs.py

 # Copyright (c) OpenMMLab. All rights reserved.
 import torch
-from mmcv.ops import nms_bev as nms_gpu
-from mmcv.ops import nms_normal_bev as nms_normal_gpu

+from mmdet3d.core.post_processing import nms_bev, nms_normal_bev
 from ..bbox import bbox3d2result, bbox3d_mapping_back, xywhr2xyxyr

@@ -52,9 +51,9 @@ def merge_aug_bboxes_3d(aug_results, img_metas, test_cfg):
     # TODO: use a more elegent way to deal with nms
     if test_cfg.use_rotate_nms:
-        nms_func = nms_gpu
+        nms_func = nms_bev
     else:
-        nms_func = nms_normal_gpu
+        nms_func = nms_normal_bev

     merged_bboxes = []
     merged_scores = []
mmdet3d/datasets/__init__.py

 # Copyright (c) OpenMMLab. All rights reserved.
 from mmdet.datasets.builder import build_dataloader
-from .builder import DATASETS, build_dataset
+from .builder import DATASETS, PIPELINES, build_dataset
 from .custom_3d import Custom3DDataset
 from .custom_3d_seg import Custom3DSegDataset
 from .kitti_dataset import KittiDataset

@@ -41,5 +41,5 @@ __all__ = [
     'LoadPointsFromMultiSweeps', 'WaymoDataset', 'BackgroundPointsFilter',
     'VoxelBasedPointSampler', 'get_loading_pipeline', 'RandomDropPointsColor',
     'RandomJitterPoints', 'ObjectNameFilter', 'AffineResize',
-    'RandomShiftScale', 'LoadPointsFromDict'
+    'RandomShiftScale', 'LoadPointsFromDict', 'PIPELINES'
 ]
mmdet3d/datasets/builder.py

@@ -3,7 +3,6 @@ import platform
 from mmcv.utils import Registry, build_from_cfg
-from mmdet.datasets import DATASETS
 from mmdet.datasets.builder import _concat_dataset

 if platform.system() != 'Windows':

@@ -16,6 +15,8 @@ if platform.system() != 'Windows':
     resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit))

 OBJECTSAMPLERS = Registry('Object sampler')
+DATASETS = Registry('dataset')
+PIPELINES = Registry('pipeline')


 def build_dataset(cfg, default_args=None):
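With DATASETS and PIPELINES now defined as mmdet3d-local registries, downstream code registers components against these instead of reusing mmdet's registry. A hedged sketch of that usage (the class names below are illustrative and not part of this commit):

from mmdet3d.datasets.builder import DATASETS, PIPELINES


@DATASETS.register_module()
class MyLidarDataset:  # hypothetical dataset, for illustration only
    ...


@PIPELINES.register_module()
class MyPointTransform:  # hypothetical transform, for illustration only
    def __call__(self, results):
        return results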
mmdet3d/datasets/custom_3d.py

@@ -7,8 +7,8 @@ import mmcv
 import numpy as np
 from torch.utils.data import Dataset

-from mmdet.datasets import DATASETS
 from ..core.bbox import get_box_type
+from .builder import DATASETS
 from .pipelines import Compose
 from .utils import extract_result_dict, get_loading_pipeline

@@ -20,6 +20,23 @@ class Custom3DDataset(Dataset):
     This is the base dataset of SUNRGB-D, ScanNet, nuScenes, and KITTI
     dataset.
+
+    .. code-block:: none
+
+    [
+        {'sample_idx':
+         'lidar_points': {'lidar_path': velodyne_path,
+                           ....
+                         },
+         'annos': {'box_type_3d':  (str)  'LiDAR/Camera/Depth'
+                   'gt_bboxes_3d':  <np.ndarray> (n, 7)
+                   'gt_names':  [list]
+                   ....
+               }
+         'calib': { .....}
+         'images': { .....}
+        }
+    ]

     Args:
         data_root (str): Path of dataset root.
         ann_file (str): Path of annotation file.

@@ -113,8 +130,9 @@ class Custom3DDataset(Dataset):
                 - ann_info (dict): Annotation info.
         """
         info = self.data_infos[index]
-        sample_idx = info['point_cloud']['lidar_idx']
-        pts_filename = osp.join(self.data_root, info['pts_path'])
+        sample_idx = info['sample_idx']
+        pts_filename = osp.join(self.data_root,
+                                info['lidar_points']['lidar_path'])

         input_dict = dict(
             pts_filename=pts_filename,

@@ -128,6 +146,45 @@ class Custom3DDataset(Dataset):
             return None
         return input_dict

+    def get_ann_info(self, index):
+        """Get annotation info according to the given index.
+
+        Args:
+            index (int): Index of the annotation data to get.
+
+        Returns:
+            dict: Annotation information consists of the following keys:
+
+                - gt_bboxes_3d (:obj:`LiDARInstance3DBoxes`):
+                    3D ground truth bboxes
+                - gt_labels_3d (np.ndarray): Labels of ground truths.
+                - gt_names (list[str]): Class names of ground truths.
+        """
+        info = self.data_infos[index]
+        gt_bboxes_3d = info['annos']['gt_bboxes_3d']
+        gt_names_3d = info['annos']['gt_names']
+        gt_labels_3d = []
+        for cat in gt_names_3d:
+            if cat in self.CLASSES:
+                gt_labels_3d.append(self.CLASSES.index(cat))
+            else:
+                gt_labels_3d.append(-1)
+        gt_labels_3d = np.array(gt_labels_3d)
+
+        # Obtain original box 3d type in info file
+        ori_box_type_3d = info['annos']['box_type_3d']
+        ori_box_type_3d, _ = get_box_type(ori_box_type_3d)
+
+        # turn original box type to target box type
+        gt_bboxes_3d = ori_box_type_3d(
+            gt_bboxes_3d,
+            box_dim=gt_bboxes_3d.shape[-1],
+            origin=(0.5, 0.5, 0.5)).convert_to(self.box_mode_3d)
+
+        anns_results = dict(
+            gt_bboxes_3d=gt_bboxes_3d, gt_labels_3d=gt_labels_3d)
+        return anns_results
+
     def pre_pipeline(self, results):
         """Initialization before data preparation.
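The reworked get_data_info and the new get_ann_info assume the unified info-file layout sketched in the class docstring above. A minimal, purely illustrative entry in that layout (field names come from the diff; the concrete values are made up):

import numpy as np

# Illustrative data info entry in the new unified format.
info = dict(
    sample_idx=8,
    lidar_points=dict(lidar_path='training/velodyne/000008.bin'),
    annos=dict(
        box_type_3d='LiDAR',
        gt_bboxes_3d=np.zeros((1, 7), dtype=np.float32),
        gt_names=['Car']))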
mmdet3d/datasets/custom_3d_seg.py

@@ -7,8 +7,8 @@ import mmcv
 import numpy as np
 from torch.utils.data import Dataset

-from mmdet.datasets import DATASETS
 from mmseg.datasets import DATASETS as SEG_DATASETS
+from .builder import DATASETS
 from .pipelines import Compose
 from .utils import extract_result_dict, get_loading_pipeline
mmdet3d/datasets/kitti2d_dataset.py

@@ -2,7 +2,8 @@
 import mmcv
 import numpy as np
-from mmdet.datasets import DATASETS, CustomDataset
+from mmdet.datasets import CustomDataset
+from .builder import DATASETS


 @DATASETS.register_module()
mmdet3d/datasets/kitti_dataset.py

@@ -9,10 +9,10 @@ import numpy as np
 import torch
 from mmcv.utils import print_log

-from mmdet.datasets import DATASETS
 from ..core import show_multi_modality_result, show_result
 from ..core.bbox import (Box3DMode, CameraInstance3DBoxes, Coord3DMode,
                          LiDARInstance3DBoxes, points_cam2img)
+from .builder import DATASETS
 from .custom_3d import Custom3DDataset
 from .pipelines import Compose
mmdet3d/datasets/kitti_mono_dataset.py

@@ -8,8 +8,8 @@ import numpy as np
 import torch
 from mmcv.utils import print_log

-from mmdet.datasets import DATASETS
 from ..core.bbox import Box3DMode, CameraInstance3DBoxes, points_cam2img
+from .builder import DATASETS
 from .nuscenes_mono_dataset import NuScenesMonoDataset

@@ -35,6 +35,8 @@ class KittiMonoDataset(NuScenesMonoDataset):
     def __init__(self,
                  data_root,
                  info_file,
+                 ann_file,
+                 pipeline,
                  load_interval=1,
                  with_velocity=False,
                  eval_version=None,

@@ -42,6 +44,8 @@ class KittiMonoDataset(NuScenesMonoDataset):
                  **kwargs):
         super().__init__(
             data_root=data_root,
+            ann_file=ann_file,
+            pipeline=pipeline,
             load_interval=load_interval,
             with_velocity=with_velocity,
             eval_version=eval_version,
mmdet3d/datasets/lyft_dataset.py

@@ -11,9 +11,9 @@ from lyft_dataset_sdk.utils.data_classes import Box as LyftBox
 from pyquaternion import Quaternion

 from mmdet3d.core.evaluation.lyft_eval import lyft_eval
-from mmdet.datasets import DATASETS
 from ..core import show_result
 from ..core.bbox import Box3DMode, Coord3DMode, LiDARInstance3DBoxes
+from .builder import DATASETS
 from .custom_3d import Custom3DDataset
 from .pipelines import Compose
mmdet3d/datasets/nuscenes_dataset.py

@@ -7,9 +7,9 @@ import numpy as np
 import pyquaternion
 from nuscenes.utils.data_classes import Box as NuScenesBox

-from mmdet.datasets import DATASETS
 from ..core import show_result
 from ..core.bbox import Box3DMode, Coord3DMode, LiDARInstance3DBoxes
+from .builder import DATASETS
 from .custom_3d import Custom3DDataset
 from .pipelines import Compose

@@ -125,8 +125,7 @@ class NuScenesDataset(Custom3DDataset):
                 filter_empty_gt=True,
                 test_mode=False,
                 eval_version='detection_cvpr_2019',
-                 use_valid_flag=False,
-                 **kwargs):
+                 use_valid_flag=False):
        self.load_interval = load_interval
        self.use_valid_flag = use_valid_flag
        super().__init__(

@@ -137,8 +136,7 @@ class NuScenesDataset(Custom3DDataset):
            modality=modality,
            box_type_3d=box_type_3d,
            filter_empty_gt=filter_empty_gt,
-            test_mode=test_mode,
-            **kwargs)
+            test_mode=test_mode)
        self.with_velocity = with_velocity
        self.eval_version = eval_version

@@ -186,7 +184,6 @@ class NuScenesDataset(Custom3DDataset):
        Returns:
            list[dict]: List of annotations sorted by timestamps.
        """
-        # loading data from a file-like object needs file format
        data = mmcv.load(ann_file, file_format='pkl')
        data_infos = list(sorted(data['infos'], key=lambda e: e['timestamp']))
        data_infos = data_infos[::self.load_interval]
mmdet3d/datasets/nuscenes_mono_dataset.py

@@ -11,9 +11,10 @@ import torch
 from nuscenes.utils.data_classes import Box as NuScenesBox

 from mmdet3d.core import bbox3d2result, box3d_multiclass_nms, xywhr2xyxyr
-from mmdet.datasets import DATASETS, CocoDataset
+from mmdet.datasets import CocoDataset
 from ..core import show_multi_modality_result
 from ..core.bbox import CameraInstance3DBoxes, get_box_type
+from .builder import DATASETS
 from .pipelines import Compose
 from .utils import extract_result_dict, get_loading_pipeline

@@ -76,6 +77,8 @@ class NuScenesMonoDataset(CocoDataset):
     def __init__(self,
                  data_root,
+                 ann_file,
+                 pipeline,
                  load_interval=1,
                  with_velocity=True,
                  modality=None,

@@ -83,9 +86,46 @@ class NuScenesMonoDataset(CocoDataset):
                  eval_version='detection_cvpr_2019',
                  use_valid_flag=False,
                  version='v1.0-trainval',
-                 **kwargs):
-        super().__init__(**kwargs)
+                 classes=None,
+                 img_prefix='',
+                 seg_prefix=None,
+                 proposal_file=None,
+                 test_mode=False,
+                 filter_empty_gt=True,
+                 file_client_args=dict(backend='disk')):
+        self.ann_file = ann_file
         self.data_root = data_root
+        self.img_prefix = img_prefix
+        self.seg_prefix = seg_prefix
+        self.proposal_file = proposal_file
+        self.test_mode = test_mode
+        self.filter_empty_gt = filter_empty_gt
+        self.CLASSES = self.get_classes(classes)
+        self.file_client = mmcv.FileClient(**file_client_args)
+
+        # load annotations (and proposals)
+        with self.file_client.get_local_path(self.ann_file) as local_path:
+            self.data_infos = self.load_annotations(local_path)
+
+        if self.proposal_file is not None:
+            with self.file_client.get_local_path(
+                    self.proposal_file) as local_path:
+                self.proposals = self.load_proposals(local_path)
+        else:
+            self.proposals = None
+
+        # filter images too small and containing no annotations
+        if not test_mode:
+            valid_inds = self._filter_imgs()
+            self.data_infos = [self.data_infos[i] for i in valid_inds]
+            if self.proposals is not None:
+                self.proposals = [self.proposals[i] for i in valid_inds]
+            # set group flag for the sampler
+            self._set_group_flag()
+
+        # processing pipeline
+        self.pipeline = Compose(pipeline)
+
         self.load_interval = load_interval
         self.with_velocity = with_velocity
         self.modality = modality
mmdet3d/datasets/pipelines/__init__.py

 # Copyright (c) OpenMMLab. All rights reserved.
-from mmdet.datasets.pipelines import Compose
+from .compose import Compose
 from .dbsampler import DataBaseSampler
 from .formating import Collect3D, DefaultFormatBundle, DefaultFormatBundle3D
 from .loading import (LoadAnnotations3D, LoadImageFromFileMono3D,
mmdet3d/datasets/pipelines/compose.py (new file, mode 100644)

+# Copyright (c) OpenMMLab. All rights reserved.
+import collections
+
+from mmcv.utils import build_from_cfg
+
+from mmdet.datasets.builder import PIPELINES as MMDET_PIPELINES
+from ..builder import PIPELINES
+
+
+@PIPELINES.register_module()
+class Compose:
+    """Compose multiple transforms sequentially. The pipeline registry of
+    mmdet3d separates with mmdet, however, sometimes we may need to use mmdet's
+    pipeline. So the class is rewritten to be able to use pipelines from both
+    mmdet3d and mmdet.
+
+    Args:
+        transforms (Sequence[dict | callable]): Sequence of transform object or
+            config dict to be composed.
+    """
+
+    def __init__(self, transforms):
+        assert isinstance(transforms, collections.abc.Sequence)
+        self.transforms = []
+        for transform in transforms:
+            if isinstance(transform, dict):
+                if transform['type'] in PIPELINES._module_dict.keys():
+                    transform = build_from_cfg(transform, PIPELINES)
+                else:
+                    transform = build_from_cfg(transform, MMDET_PIPELINES)
+                self.transforms.append(transform)
+            elif callable(transform):
+                self.transforms.append(transform)
+            else:
+                raise TypeError('transform must be callable or a dict')
+
+    def __call__(self, data):
+        """Call function to apply transforms sequentially.
+
+        Args:
+            data (dict): A result dict contains the data to transform.
+
+        Returns:
+            dict: Transformed data.
+        """
+        for t in self.transforms:
+            data = t(data)
+            if data is None:
+                return None
+        return data
+
+    def __repr__(self):
+        format_string = self.__class__.__name__ + '('
+        for t in self.transforms:
+            format_string += '\n'
+            format_string += f'    {t}'
+        format_string += '\n)'
+        return format_string
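Because each config dict is looked up in mmdet3d's PIPELINES registry first and only then in mmdet's, a single pipeline can now mix transforms from both packages. A hedged construction-only sketch (the transform names and their options are the usual ones from the two codebases, shown for illustration rather than taken from this commit):

from mmdet3d.datasets.pipelines import Compose

# 'LoadPointsFromFile' resolves from the mmdet3d registry; 'RandomFlip'
# is assumed to fall back to mmdet's registry.
pipeline = Compose([
    dict(type='LoadPointsFromFile', coord_type='LIDAR',
         load_dim=4, use_dim=4),
    dict(type='RandomFlip', flip_ratio=0.5),
])
print(pipeline)  # __repr__ lists the resolved transforms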