OpenDAS / mmdetection3d · Commit ff1e5b4e

[Model] Refactor basedetector and singestagedetector and add Det3DDataPreprocessor

Authored Jun 16, 2022 by ZCMax; committed by ChaimZhu on Jul 20, 2022.
Parent: eca5a9f2

Showing 7 changed files with 486 additions and 203 deletions (+486 -203).
Changed files:

mmdet3d/core/utils/__init__.py (+8 -1)
mmdet3d/core/utils/typing.py (+31 -0)
mmdet3d/models/data_preprocessors/__init__.py (+4 -0)
mmdet3d/models/data_preprocessors/data_preprocessor.py (+184 -0)
mmdet3d/models/detectors/base.py (+60 -167)
mmdet3d/models/detectors/single_stage.py (+107 -35)
tests/test_models/test_preprocessors/test_data_preprocessor.py (+92 -0)
mmdet3d/core/utils/__init__.py

Post-change content of the hunk (the `.typing` import and the corresponding new names in `__all__` are the additions):

@@ -2,9 +2,16 @@
from .array_converter import ArrayConverter, array_converter
from .gaussian import (draw_heatmap_gaussian, ellip_gaussian2D, gaussian_2d,
                       gaussian_radius, get_ellip_gaussian_2D)
from .typing import (ConfigType, ForwardResults, InstanceList, MultiConfig,
                     OptConfigType, OptInstanceList, OptMultiConfig,
                     OptSampleList, OptSamplingResultList, SampleList,
                     SamplingResultList)

__all__ = [
    'gaussian_2d', 'gaussian_radius', 'draw_heatmap_gaussian',
    'ArrayConverter', 'array_converter', 'ellip_gaussian2D',
    'get_ellip_gaussian_2D', 'ConfigType', 'OptConfigType', 'MultiConfig',
    'OptMultiConfig', 'InstanceList', 'OptInstanceList', 'SampleList',
    'OptSampleList', 'SamplingResultList', 'ForwardResults',
    'OptSamplingResultList'
]
mmdet3d/core/utils/typing.py (new file, mode 100644)

# Copyright (c) OpenMMLab. All rights reserved.
"""Collecting some commonly used type hints in MMDetection3D."""
from typing import Dict, List, Optional, Tuple, Union

import torch
from mmengine.config import ConfigDict
from mmengine.data import InstanceData

from ..bbox.samplers import SamplingResult
from ..data_structures import Det3DDataSample

# Type hint of config data
ConfigType = Union[ConfigDict, dict]
OptConfigType = Optional[ConfigType]
# Type hint of one or more config data
MultiConfig = Union[ConfigType, List[ConfigType]]
OptMultiConfig = Optional[MultiConfig]

InstanceList = List[InstanceData]
OptInstanceList = Optional[InstanceList]

SampleList = List[Det3DDataSample]
OptSampleList = Optional[SampleList]

SamplingResultList = List[SamplingResult]
OptSamplingResultList = Optional[SamplingResultList]

ForwardResults = Union[Dict[str, torch.Tensor], List[Det3DDataSample],
                       Tuple[torch.Tensor], torch.Tensor]
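
To see why these aliases help, here is a minimal, self-contained sketch of how they read in a function signature. The `build_head` helper is hypothetical, and the aliases are re-declared locally with a plain `dict` standing in for `ConfigDict`, so nothing below depends on mmdet3d itself.

# Hypothetical illustration of the aliases above; not part of this commit.
from typing import List, Optional, Union

# Local stand-ins mirroring mmdet3d/core/utils/typing.py, with ``dict``
# in place of ``ConfigDict``.
ConfigType = dict
OptConfigType = Optional[ConfigType]
MultiConfig = Union[ConfigType, List[ConfigType]]
OptMultiConfig = Optional[MultiConfig]


def build_head(cfg: ConfigType, init_cfg: OptMultiConfig = None) -> dict:
    # Without the aliases this signature would read
    # ``cfg: dict, init_cfg: Optional[Union[dict, List[dict]]] = None``.
    return dict(cfg=cfg, init_cfg=init_cfg)


print(build_head({'type': 'VoteHead'}))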
mmdet3d/models/data_preprocessors/__init__.py (new file, mode 100644)

# Copyright (c) OpenMMLab. All rights reserved.
from .data_preprocessor import Det3DDataPreprocessor

__all__ = ['Det3DDataPreprocessor']
mmdet3d/models/data_preprocessors/data_preprocessor.py (new file, mode 100644)

# Copyright (c) OpenMMLab. All rights reserved.
from numbers import Number
from typing import Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
from mmengine.data import BaseDataElement
from mmengine.model import stack_batch

from mmdet3d.registry import MODELS
from mmdet.models import DetDataPreprocessor


@MODELS.register_module()
class Det3DDataPreprocessor(DetDataPreprocessor):
    """Points (image) pre-processor for point-cloud / multi-modality 3D
    detection tasks.

    It provides the following data pre-processing:

    - Collate and move data to the target device.
    - Pad images in the inputs to the maximum size of the current batch with
      the defined ``pad_value``. The padded size is made divisible by the
      defined ``pad_size_divisor``.
    - Stack images in the inputs to ``batch_imgs``.
    - Convert images in the inputs from BGR to RGB if the shape of the input
      is (3, H, W).
    - Normalize images in the inputs with the defined std and mean.

    Args:
        mean (Sequence[Number], optional): The pixel mean of R, G, B channels.
            Defaults to None.
        std (Sequence[Number], optional): The pixel standard deviation of
            R, G, B channels. Defaults to None.
        pad_size_divisor (int): The size of the padded image should be
            divisible by ``pad_size_divisor``. Defaults to 1.
        pad_value (Number): The padded pixel value. Defaults to 0.
        bgr_to_rgb (bool): Whether to convert images from BGR to RGB.
            Defaults to False.
        rgb_to_bgr (bool): Whether to convert images from RGB to BGR.
            Defaults to False.
    """

    def __init__(self,
                 mean: Sequence[Number] = None,
                 std: Sequence[Number] = None,
                 pad_size_divisor: int = 1,
                 pad_value: Union[float, int] = 0,
                 pad_mask: bool = False,
                 mask_pad_value: int = 0,
                 pad_seg: bool = False,
                 seg_pad_value: int = 255,
                 bgr_to_rgb: bool = False,
                 rgb_to_bgr: bool = False,
                 batch_augments: Optional[List[dict]] = None):
        super().__init__(
            mean=mean,
            std=std,
            pad_size_divisor=pad_size_divisor,
            pad_value=pad_value,
            pad_mask=pad_mask,
            mask_pad_value=mask_pad_value,
            pad_seg=pad_seg,
            seg_pad_value=seg_pad_value,
            bgr_to_rgb=bgr_to_rgb,
            rgb_to_bgr=rgb_to_bgr,
            batch_augments=batch_augments)

    def forward(self,
                data: Sequence[dict],
                training: bool = False) -> Tuple[Dict, Optional[list]]:
        """Perform normalization, padding and bgr2rgb conversion based on
        ``BaseDataPreprocessor``.

        Args:
            data (Sequence[dict]): Data sampled from the dataloader.
            training (bool): Whether to enable training-time augmentation.

        Returns:
            Tuple[Dict, Optional[list]]: Data in the same format as the
            model input.
        """
        inputs_dict, batch_data_samples = self.collate_data(data)

        if 'points' in inputs_dict[0].keys():
            points = [input['points'] for input in inputs_dict]
        else:
            raise KeyError(
                "Model input dict needs to include the 'points' key.")

        if 'img' in inputs_dict[0].keys():
            imgs = [input['img'] for input in inputs_dict]
            # channel transform
            if self.channel_conversion:
                imgs = [_img[[2, 1, 0], ...] for _img in imgs]
            # Normalization.
            if self._enable_normalize:
                imgs = [(_img - self.mean) / self.std for _img in imgs]
            # Pad and stack Tensor.
            batch_imgs = stack_batch(imgs, self.pad_size_divisor,
                                     self.pad_value)
            batch_pad_shape = self._get_pad_shape(data)

            if batch_data_samples is not None:
                # NOTE the batched image size information may be useful, e.g.
                batch_input_shape = tuple(batch_imgs[0].size()[-2:])
                for data_samples, pad_shape in zip(batch_data_samples,
                                                   batch_pad_shape):
                    data_samples.set_metainfo({
                        'batch_input_shape': batch_input_shape,
                        'pad_shape': pad_shape
                    })

                if self.pad_mask:
                    self.pad_gt_masks(batch_data_samples)

                if self.pad_seg:
                    self.pad_gt_sem_seg(batch_data_samples)

            if training and self.batch_augments is not None:
                for batch_aug in self.batch_augments:
                    batch_imgs, batch_data_samples = batch_aug(
                        batch_imgs, batch_data_samples)
        else:
            imgs = None

        batch_inputs_dict = {
            'points': points,
            'imgs': batch_imgs if imgs is not None else None
        }

        return batch_inputs_dict, batch_data_samples

    def collate_data(
            self,
            data: Sequence[dict]) -> Tuple[List[dict], Optional[list]]:
        """Collate and copy data to the target device.

        Collates the data sampled from the dataloader into a list of dicts
        and a list of labels, and then copies the tensors to the target
        device.

        Args:
            data (Sequence[dict]): Data sampled from the dataloader.

        Returns:
            Tuple[List[Dict], Optional[list]]: Unstacked list of input
            data dicts and list of labels on the target device.
        """
        # rewrite `collate_data` since the inputs is a dict instead of
        # an image tensor.
        inputs_dict = [{
            k: v.to(self._device)
            for k, v in _data['inputs'].items()
        } for _data in data]

        batch_data_samples: List[BaseDataElement] = []
        # Model can get predictions without any data samples.
        for _data in data:
            if 'data_sample' in _data:
                batch_data_samples.append(_data['data_sample'])
        # Move data from CPU to corresponding device.
        batch_data_samples = [
            data_sample.to(self._device) for data_sample in batch_data_samples
        ]

        if not batch_data_samples:
            batch_data_samples = None  # type: ignore

        return inputs_dict, batch_data_samples

    def _get_pad_shape(self, data: Sequence[dict]) -> List[tuple]:
        """Get the pad_shape of each image based on the data and
        ``pad_size_divisor``."""
        # rewrite `_get_pad_shape` for obtaining image inputs.
        ori_inputs = [_data['inputs']['img'] for _data in data]
        batch_pad_shape = []
        for ori_input in ori_inputs:
            pad_h = int(np.ceil(ori_input.shape[1] /
                                self.pad_size_divisor)) * self.pad_size_divisor
            pad_w = int(np.ceil(ori_input.shape[2] /
                                self.pad_size_divisor)) * self.pad_size_divisor
            batch_pad_shape.append((pad_h, pad_w))
        return batch_pad_shape
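
As a quick orientation for how this preprocessor is driven, the sketch below mirrors the unit test added at the end of this commit: it builds a `Det3DDataPreprocessor` on CPU and feeds it one toy sample. The tensor shapes are illustrative only, and the imports assume an environment where this branch of mmdet3d (with its mmengine/mmdet dev dependencies) is installed.

# Minimal usage sketch (mirrors the unit test below); shapes are illustrative.
import torch
from mmdet3d.core import Det3DDataSample
from mmdet3d.models.data_preprocessors import Det3DDataPreprocessor

processor = Det3DDataPreprocessor(mean=[0., 0., 0.], std=[1., 1., 1.])

data = [{
    'inputs': {
        'points': torch.randn((5000, 3)),          # raw point cloud
        'img': torch.randint(0, 256, (3, 11, 10))  # (C, H, W) image
    },
    'data_sample': Det3DDataSample()
}]

batch_inputs_dict, batch_data_samples = processor(data)
# 'points' stays a per-sample list; 'imgs' is normalized, padded and stacked.
print(len(batch_inputs_dict['points']), batch_inputs_dict['imgs'].shape)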
mmdet3d/models/detectors/base.py

Post-commit version of the refactored file:

# Copyright (c) OpenMMLab. All rights reserved.
from mmdet3d.core import Det3DDataSample
from mmdet3d.core.utils import (ForwardResults, InstanceList, OptConfigType,
                                OptMultiConfig, OptSampleList, SampleList)
from mmdet3d.registry import MODELS
from mmdet.models import BaseDetector


@MODELS.register_module()
class Base3DDetector(BaseDetector):
    """Base class for 3D detectors.

    Args:
        data_preprocessor (dict or ConfigDict, optional): The pre-process
            config of :class:`BaseDataPreprocessor`. It usually includes
            ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
        init_cfg (dict or ConfigDict, optional): The config to control the
            initialization. Defaults to None.
    """

    def __init__(self,
                 data_processor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(
            data_preprocessor=data_processor, init_cfg=init_cfg)

    def forward(self,
                batch_inputs_dict: dict,
                batch_data_samples: OptSampleList = None,
                mode: str = 'tensor',
                **kwargs) -> ForwardResults:
        """The unified entry for a forward process in both training and test.

        The method should accept three modes: "tensor", "predict" and "loss":

        - "tensor": Forward the whole network and return tensor or tuple of
          tensor without any post-processing, same as a common nn.Module.
        - "predict": Forward and return the predictions, which are fully
          processed to a list of :obj:`DetDataSample`.
        - "loss": Forward and return a dict of losses according to the given
          inputs and data samples.

        Note that this method doesn't handle either back propagation or
        optimizer updating, which are done in the :meth:`train_step`.

        Args:
            batch_inputs (torch.Tensor): The input tensor with shape
                (N, C, ...) in general.
            batch_data_samples (list[:obj:`DetDataSample`], optional): The
                annotation data of every sample. Defaults to None.
            mode (str): Return what kind of value. Defaults to 'tensor'.

        Returns:
            The return type depends on ``mode``.

            - If ``mode="tensor"``, return a tensor or a tuple of tensor.
            - If ``mode="predict"``, return a list of :obj:`DetDataSample`.
            - If ``mode="loss"``, return a dict of tensor.
        """
        if mode == 'loss':
            return self.loss(batch_inputs_dict, batch_data_samples, **kwargs)
        elif mode == 'predict':
            return self.predict(batch_inputs_dict, batch_data_samples,
                                **kwargs)
        elif mode == 'tensor':
            return self._forward(batch_inputs_dict, batch_data_samples,
                                 **kwargs)
        else:
            raise RuntimeError(f'Invalid mode "{mode}". '
                               'Only supports loss, predict and tensor mode')

    def convert_to_datasample(self,
                              results_list: InstanceList) -> SampleList:
        """Convert results list to `Det3DDataSample`.

        Subclasses could override it to be compatible with some
        multi-modality 3D detectors.

        Args:
            results_list (list[:obj:`InstanceData`]): Detection results of
                each sample.

        Returns:
            list[:obj:`Det3DDataSample`]: Detection results of the input.
            Each Det3DDataSample usually contains 'pred_instances_3d'. And
            the ``pred_instances_3d`` usually contains the following keys.

            - scores_3d (Tensor): Classification scores, has a shape
              (num_instances, )
            - labels_3d (Tensor): Labels of 3D bboxes, has a shape
              (num_instances, ).
            - bboxes_3d (Tensor): Contains a tensor with shape
              (num_instances, C) where C >= 7.
        """
        out_results_list = []
        for i in range(len(results_list)):
            result = Det3DDataSample()
            result.pred_instances_3d = results_list[i]
            out_results_list.append(result)
        return out_results_list

Removed by this commit (superseded by the interface above):

- the `preprocess_cfg`-based `__init__(self, preprocess_cfg=None, init_cfg=None)` and its docstring;
- `forward_simple_test(self, batch_inputs_dict, batch_data_samples, **kwargs)`, which collected per-sample metainfo and dispatched to `simple_test` or `aug_test` depending on the batch size;
- the old `forward(self, data, optimizer=None, return_loss=False, **kwargs)` iteration step, which called `preprocess_data` and then either `forward_train` plus `_parse_losses` (returning a `loss`/`log_vars`/`num_samples` dict) or the test path;
- `preprocess_data(self, data)`, which moved points and images to the device, optionally converted BGR to RGB, normalized with `pixel_mean`/`pixel_std`, and stacked images with `stack_batch` according to `preprocess_cfg`;
- `postprocess_result(self, results_list)` (its role is now played by `convert_to_datasample`) and the placeholder `show_results` method;
- imports no longer referenced by the remaining code (`Dict`/`List`/`Optional`/`Union` from `typing`, `torch`, `InstanceData`, `torch.optim.Optimizer`, `stack_batch`, and the duplicate `BaseDetector` import from `mmdet.models.detectors`).
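
As a side note on the new `forward(..., mode=...)` entry point defined above, the self-contained toy below replays the same three-way dispatch with a hypothetical `ToyDetector`, so the control flow can be run in isolation. It is only an illustration of the pattern, not mmdetection3d code.

# A self-contained toy mirroring the mode dispatch in Base3DDetector.forward;
# the class and data here are hypothetical stand-ins, not part of this commit.
from typing import Optional


class ToyDetector:

    def loss(self, inputs: dict, samples: Optional[list]) -> dict:
        return {'loss_dummy': 0.0}

    def predict(self, inputs: dict, samples: Optional[list]) -> list:
        return samples or []

    def _forward(self, inputs: dict, samples: Optional[list]) -> tuple:
        return (inputs['points'],)

    def forward(self, inputs: dict, samples: Optional[list] = None,
                mode: str = 'tensor'):
        # Same three-way dispatch as Base3DDetector.forward.
        if mode == 'loss':
            return self.loss(inputs, samples)
        elif mode == 'predict':
            return self.predict(inputs, samples)
        elif mode == 'tensor':
            return self._forward(inputs, samples)
        raise RuntimeError(f'Invalid mode "{mode}". '
                           'Only supports loss, predict and tensor mode')


det = ToyDetector()
print(det.forward({'points': [1, 2, 3]}, mode='loss'))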
mmdet3d/models/detectors/single_stage.py

Post-commit version (unchanged context that the diff collapsed is shown as ...):

# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple, Union

import torch

from mmdet3d.core.utils import (ConfigType, OptConfigType, OptMultiConfig,
                                OptSampleList, SampleList)
from mmdet3d.registry import MODELS
from .base import Base3DDetector


class SingleStage3DDetector(Base3DDetector):
    """SingleStage3DDetector.

    This class serves as a base class for single-stage 3D detectors which
    directly and densely predict 3D bounding boxes on the output features
    of the backbone+neck.

    Args:
        backbone (dict): Config dict of detector's backbone.
        ...
            Defaults to None.
        test_cfg (dict, optional): Config dict of test hyper-parameters.
            Defaults to None.
        data_preprocessor (dict or ConfigDict, optional): The pre-process
            config of :class:`BaseDataPreprocessor`. It usually includes
            ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
        init_cfg (dict or ConfigDict, optional): The config to control the
            initialization. Defaults to None.
    """

    def __init__(self,
                 backbone: ConfigType,
                 neck: OptConfigType = None,
                 bbox_head: OptConfigType = None,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(
            data_processor=data_preprocessor, init_cfg=init_cfg)
        self.backbone = MODELS.build(backbone)
        if neck is not None:
            self.neck = MODELS.build(neck)
        ...
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

    def loss(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
             **kwargs) -> Union[dict, list]:
        """Calculate losses from a batch of inputs dict and data samples.

        Args:
            batch_inputs_dict (dict): The model input dict which includes
                'points' and 'img' keys.

                - points (list[torch.Tensor]): Point cloud of each sample.
                - imgs (torch.Tensor, optional): Image of each sample.
            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`.

        Returns:
            dict: A dictionary of loss components.
        """
        x = self.extract_feat(batch_inputs_dict)
        losses = self.bbox_head.loss(x, batch_data_samples, **kwargs)
        return losses

    def predict(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
                **kwargs) -> SampleList:
        """Predict results from a batch of inputs and data samples with
        post-processing.

        Args:
            batch_inputs_dict (dict): The model input dict which includes
                'points' and 'img' keys.

                - points (list[torch.Tensor]): Point cloud of each sample.
                - imgs (torch.Tensor, optional): Image of each sample.
            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`.
            rescale (bool): Whether to rescale the results.
                Defaults to True.

        Returns:
            list[:obj:`Det3DDataSample`]: Detection results of the
            input images. Each Det3DDataSample usually contains
            'pred_instances_3d'. And the ``pred_instances_3d`` usually
            contains the following keys.

            - scores_3d (Tensor): Classification scores, has a shape
              (num_instances, )
            - labels_3d (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes_3d (Tensor): Contains a tensor with shape
              (num_instances, C) where C >= 7.
        """
        x = self.extract_feat(batch_inputs_dict)
        results_list = self.bbox_head.predict(x, batch_data_samples, **kwargs)
        predictions = self.convert_to_datasample(results_list)
        return predictions

    def _forward(self,
                 batch_inputs_dict: dict,
                 data_samples: OptSampleList = None,
                 **kwargs) -> Tuple[List[torch.Tensor]]:
        """Network forward process. Usually includes backbone, neck and head
        forward without any post-processing.

        Args:
            batch_inputs_dict (dict): The model input dict which includes
                'points' and 'img' keys.

                - points (list[torch.Tensor]): Point cloud of each sample.
                - imgs (torch.Tensor, optional): Image of each sample.
            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`.

        Returns:
            tuple[list]: A tuple of features from ``bbox_head`` forward.
        """
        x = self.extract_feat(batch_inputs_dict)
        results = self.bbox_head.forward(x)
        return results

    def extract_feat(self, batch_inputs_dict: dict) -> Tuple[torch.Tensor]:
        """Directly extract features from the backbone+neck.

        Args:
            batch_inputs_dict (dict): The model input dict, whose 'points'
                entry holds the input points.
        """
        points = batch_inputs_dict['points']
        stack_points = torch.stack(points)
        x = self.backbone(stack_points)
        if self.with_neck:
            x = self.neck(x)
        return x

    def extract_feats(self, batch_inputs_dict: dict) -> list:
        """Extract features of multiple samples."""
        return [
            self.extract_feat([points])
            for points in batch_inputs_dict['points']
        ]

Removed by this commit:

- the old `__init__` signature `(backbone, neck=None, bbox_head=None, train_cfg=None, test_cfg=None, preprocess_cfg=None, init_cfg=None, pretrained=None)` and the `pretrained` docstring entry;
- `forward_dummy(self, batch_inputs)`, which was used for computing network FLOPs (see `mmdetection/tools/analysis_tools/get_flops.py`) and ran `bbox_head` with an optional `train_cfg.sample_mod`;
- the old `extract_feat(self, points: List[torch.Tensor])`, which ran the backbone on `points[0]`.
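
For context on how a downstream config might pick up the new field, here is a hedged sketch of a model config using `data_preprocessor`. The backbone and head types are placeholders (not taken from this commit), and in practice a concrete subclass of `SingleStage3DDetector` would usually be configured rather than the base class itself.

# Hypothetical config sketch; only the ``data_preprocessor`` block reflects
# the field introduced by this refactor, the rest are placeholders.
model = dict(
    type='SingleStage3DDetector',          # a concrete detector in practice
    data_preprocessor=dict(
        type='Det3DDataPreprocessor',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True,
        pad_size_divisor=32),
    backbone=dict(type='SomeBackbone'),    # placeholder
    bbox_head=dict(type='SomeBBoxHead'),   # placeholder
    train_cfg=None,
    test_cfg=None)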
tests/test_models/test_preprocessors/test_data_preprocessor.py (new file, mode 100644)

# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch

from mmdet3d.core import Det3DDataSample
from mmdet3d.models.data_preprocessors import Det3DDataPreprocessor


class TestDet3DDataPreprocessor(TestCase):

    def test_init(self):
        # test mean is None
        processor = Det3DDataPreprocessor()
        self.assertTrue(not hasattr(processor, 'mean'))
        self.assertTrue(processor._enable_normalize is False)

        # test mean is not None
        processor = Det3DDataPreprocessor(mean=[0, 0, 0], std=[1, 1, 1])
        self.assertTrue(hasattr(processor, 'mean'))
        self.assertTrue(hasattr(processor, 'std'))
        self.assertTrue(processor._enable_normalize)

        # please specify both mean and std
        with self.assertRaises(AssertionError):
            Det3DDataPreprocessor(mean=[0, 0, 0])

        # bgr2rgb and rgb2bgr cannot be set to True at the same time
        with self.assertRaises(AssertionError):
            Det3DDataPreprocessor(bgr_to_rgb=True, rgb_to_bgr=True)

    def test_forward(self):
        processor = Det3DDataPreprocessor(mean=[0, 0, 0], std=[1, 1, 1])

        points = torch.randn((5000, 3))
        image = torch.randint(0, 256, (3, 11, 10))
        inputs_dict = dict(points=points, img=image)
        data = [{'inputs': inputs_dict, 'data_sample': Det3DDataSample()}]

        inputs, data_samples = processor(data)
        self.assertEqual(inputs['imgs'].shape, (1, 3, 11, 10))
        self.assertEqual(len(inputs['points']), 1)
        self.assertEqual(len(data_samples), 1)

        # test image channel_conversion
        processor = Det3DDataPreprocessor(
            mean=[0., 0., 0.], std=[1., 1., 1.], bgr_to_rgb=True)
        inputs, data_samples = processor(data)
        self.assertEqual(inputs['imgs'].shape, (1, 3, 11, 10))
        self.assertEqual(len(data_samples), 1)

        # test image padding
        data = [{
            'inputs': {
                'points': torch.randn((5000, 3)),
                'img': torch.randint(0, 256, (3, 10, 11))
            }
        }, {
            'inputs': {
                'points': torch.randn((5000, 3)),
                'img': torch.randint(0, 256, (3, 9, 14))
            }
        }]
        processor = Det3DDataPreprocessor(
            mean=[0., 0., 0.], std=[1., 1., 1.], bgr_to_rgb=True)
        inputs, data_samples = processor(data)
        self.assertEqual(inputs['imgs'].shape, (2, 3, 10, 14))
        self.assertIsNone(data_samples)

        # test pad_size_divisor
        data = [{
            'inputs': {
                'points': torch.randn((5000, 3)),
                'img': torch.randint(0, 256, (3, 10, 11))
            },
            'data_sample': Det3DDataSample()
        }, {
            'inputs': {
                'points': torch.randn((5000, 3)),
                'img': torch.randint(0, 256, (3, 9, 24))
            },
            'data_sample': Det3DDataSample()
        }]
        processor = Det3DDataPreprocessor(
            mean=[0., 0., 0.], std=[1., 1., 1.], pad_size_divisor=5)
        inputs, data_samples = processor(data)
        self.assertEqual(inputs['imgs'].shape, (2, 3, 10, 25))
        self.assertEqual(len(data_samples), 2)
        for data_sample, expected_shape in zip(data_samples, [(10, 15),
                                                              (10, 25)]):
            self.assertEqual(data_sample.pad_shape, expected_shape)