Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdeploy
Commits
e4fb2aa4
"detection/ops_dcnv3/vscode:/vscode.git/clone" did not exist on "c4552f794aab15e56a00ccb06747e3fa6b8bec38"
Commit
e4fb2aa4
authored
Jun 25, 2025
by
limm
Browse files
add test_mmdet3d
parent
481f872d
Pipeline
#2822
canceled with stages
Changes
26
Pipelines
2
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
700 additions
and
0 deletions
+700
-0
tests/test_codebase/test_mmdet3d/data/smoke_dla34_dlaneck_gn-all_4xb8-6x_kitti-mono3d.py
...d/data/smoke_dla34_dlaneck_gn-all_4xb8-6x_kitti-mono3d.py
+61
-0
tests/test_codebase/test_mmdet3d/test_mmdet3d_models.py
tests/test_codebase/test_mmdet3d/test_mmdet3d_models.py
+245
-0
tests/test_codebase/test_mmdet3d/test_mono_detection.py
tests/test_codebase/test_mmdet3d/test_mono_detection.py
+97
-0
tests/test_codebase/test_mmdet3d/test_mono_detection_model.py
...s/test_codebase/test_mmdet3d/test_mono_detection_model.py
+97
-0
tests/test_codebase/test_mmdet3d/test_voxel_detection.py
tests/test_codebase/test_mmdet3d/test_voxel_detection.py
+97
-0
tests/test_codebase/test_mmdet3d/test_voxel_detection_model.py
.../test_codebase/test_mmdet3d/test_voxel_detection_model.py
+103
-0
No files found.
tests/test_codebase/test_mmdet3d/data/smoke_dla34_dlaneck_gn-all_4xb8-6x_kitti-mono3d.py
0 → 100755
View file @
e4fb2aa4
# Copyright (c) OpenMMLab. All rights reserved.
_base_
=
[
'kitti-mono3d.py'
,
'smoke.py'
,
'default_runtime.py'
]
backend_args
=
None
train_pipeline
=
[
dict
(
type
=
'LoadImageFromFileMono3D'
,
backend_args
=
backend_args
),
dict
(
type
=
'LoadAnnotations3D'
,
with_bbox
=
True
,
with_label
=
True
,
with_attr_label
=
False
,
with_bbox_3d
=
True
,
with_label_3d
=
True
,
with_bbox_depth
=
True
),
dict
(
type
=
'RandomFlip3D'
,
flip_ratio_bev_horizontal
=
0.5
),
dict
(
type
=
'RandomShiftScale'
,
shift_scale
=
(
0.2
,
0.4
),
aug_prob
=
0.3
),
dict
(
type
=
'AffineResize'
,
img_scale
=
(
1280
,
384
),
down_ratio
=
4
),
dict
(
type
=
'Pack3DDetInputs'
,
keys
=
[
'img'
,
'gt_bboxes'
,
'gt_bboxes_labels'
,
'gt_bboxes_3d'
,
'gt_labels_3d'
,
'centers_2d'
,
'depths'
]),
]
test_pipeline
=
[
dict
(
type
=
'LoadImageFromFileMono3D'
,
backend_args
=
backend_args
),
dict
(
type
=
'AffineResize'
,
img_scale
=
(
1280
,
384
),
down_ratio
=
4
),
dict
(
type
=
'Pack3DDetInputs'
,
keys
=
[
'img'
])
]
train_dataloader
=
dict
(
batch_size
=
8
,
num_workers
=
4
,
dataset
=
dict
(
pipeline
=
train_pipeline
))
test_dataloader
=
dict
(
dataset
=
dict
(
pipeline
=
test_pipeline
))
val_dataloader
=
dict
(
dataset
=
dict
(
pipeline
=
test_pipeline
))
# training schedule for 6x
max_epochs
=
72
train_cfg
=
dict
(
type
=
'EpochBasedTrainLoop'
,
max_epochs
=
max_epochs
,
val_interval
=
5
)
val_cfg
=
dict
(
type
=
'ValLoop'
)
test_cfg
=
dict
(
type
=
'TestLoop'
)
# learning rate
param_scheduler
=
[
dict
(
type
=
'MultiStepLR'
,
begin
=
0
,
end
=
max_epochs
,
by_epoch
=
True
,
milestones
=
[
50
],
gamma
=
0.1
)
]
# optimizer
optim_wrapper
=
dict
(
type
=
'OptimWrapper'
,
optimizer
=
dict
(
type
=
'Adam'
,
lr
=
2.5e-4
),
clip_grad
=
None
)
find_unused_parameters
=
True
tests/test_codebase/test_mmdet3d/test_mmdet3d_models.py
0 → 100644
View file @
e4fb2aa4
# Copyright (c) OpenMMLab. All rights reserved.
import
mmengine
import
numpy
as
np
import
pytest
import
torch
from
mmdeploy.apis
import
build_task_processor
from
mmdeploy.codebase
import
import_codebase
from
mmdeploy.utils
import
Backend
,
Codebase
,
Task
,
load_config
from
mmdeploy.utils.test
import
WrapModel
,
check_backend
,
get_rewrite_outputs
try
:
import_codebase
(
Codebase
.
MMDET3D
)
except
ImportError
:
pytest
.
skip
(
f
'
{
Codebase
.
MMDET3D
}
is not installed.'
,
allow_module_level
=
True
)
model_cfg
=
load_config
(
'tests/test_codebase/test_mmdet3d/data/model_cfg.py'
)[
0
]
def
get_pillar_encoder
():
from
mmdet3d.models.voxel_encoders
import
PillarFeatureNet
model
=
PillarFeatureNet
(
in_channels
=
4
,
feat_channels
=
(
64
,
),
with_distance
=
False
,
with_cluster_center
=
True
,
with_voxel_center
=
True
,
voxel_size
=
(
0.2
,
0.2
,
4
),
point_cloud_range
=
(
0
,
-
40
,
-
3
,
70.4
,
40
,
1
),
norm_cfg
=
dict
(
type
=
'BN1d'
,
eps
=
1e-3
,
momentum
=
0.01
),
mode
=
'max'
)
model
.
requires_grad_
(
False
)
return
model
def
get_pointpillars_scatter
():
from
mmdet3d.models.middle_encoders
import
PointPillarsScatter
model
=
PointPillarsScatter
(
in_channels
=
64
,
output_shape
=
(
16
,
16
))
model
.
requires_grad_
(
False
)
return
model
@
pytest
.
mark
.
parametrize
(
'backend_type'
,
[
Backend
.
ONNXRUNTIME
])
def
test_pillar_encoder
(
backend_type
:
Backend
):
check_backend
(
backend_type
,
True
)
model
=
get_pillar_encoder
()
model
.
cpu
().
eval
()
deploy_cfg
=
mmengine
.
Config
(
dict
(
backend_config
=
dict
(
type
=
backend_type
.
value
),
onnx_config
=
dict
(
input_shape
=
None
,
input_names
=
[
'features'
,
'num_points'
,
'coors'
],
output_names
=
[
'outputs'
]),
codebase_config
=
dict
(
type
=
Codebase
.
MMDET3D
.
value
,
task
=
Task
.
VOXEL_DETECTION
.
value
)))
features
=
torch
.
rand
(
3945
,
32
,
4
)
*
100
num_points
=
torch
.
randint
(
0
,
32
,
(
3945
,
),
dtype
=
torch
.
int32
)
coors
=
torch
.
randint
(
0
,
10
,
(
3945
,
4
),
dtype
=
torch
.
int32
)
model_outputs
=
model
.
forward
(
features
,
num_points
,
coors
)
wrapped_model
=
WrapModel
(
model
,
'forward'
)
rewrite_inputs
=
{
'features'
:
features
,
'num_points'
:
num_points
,
'coors'
:
coors
}
rewrite_outputs
,
is_backend_output
=
get_rewrite_outputs
(
wrapped_model
=
wrapped_model
,
model_inputs
=
rewrite_inputs
,
deploy_cfg
=
deploy_cfg
)
if
isinstance
(
rewrite_outputs
,
dict
):
rewrite_outputs
=
rewrite_outputs
[
'output'
]
if
isinstance
(
rewrite_outputs
,
list
):
rewrite_outputs
=
rewrite_outputs
[
0
]
assert
np
.
allclose
(
model_outputs
.
shape
,
rewrite_outputs
.
shape
,
rtol
=
1e-03
,
atol
=
1e-03
)
@
pytest
.
mark
.
parametrize
(
'backend_type'
,
[
Backend
.
ONNXRUNTIME
])
def
test_pointpillars_scatter
(
backend_type
:
Backend
):
check_backend
(
backend_type
,
True
)
model
=
get_pointpillars_scatter
()
model
.
cpu
().
eval
()
deploy_cfg
=
mmengine
.
Config
(
dict
(
backend_config
=
dict
(
type
=
backend_type
.
value
),
onnx_config
=
dict
(
input_shape
=
None
,
input_names
=
[
'voxel_features'
,
'coors'
],
output_names
=
[
'outputs'
]),
codebase_config
=
dict
(
type
=
Codebase
.
MMDET3D
.
value
,
task
=
Task
.
VOXEL_DETECTION
.
value
)))
voxel_features
=
torch
.
rand
(
16
*
16
,
64
)
*
100
coors
=
torch
.
randint
(
0
,
10
,
(
16
*
16
,
4
),
dtype
=
torch
.
int32
)
model_outputs
=
model
.
forward_batch
(
voxel_features
,
coors
,
1
)
wrapped_model
=
WrapModel
(
model
,
'forward_batch'
)
rewrite_inputs
=
{
'voxel_features'
:
voxel_features
,
'coors'
:
coors
}
rewrite_outputs
,
is_backend_output
=
get_rewrite_outputs
(
wrapped_model
=
wrapped_model
,
model_inputs
=
rewrite_inputs
,
deploy_cfg
=
deploy_cfg
)
if
isinstance
(
rewrite_outputs
,
list
):
rewrite_outputs
=
rewrite_outputs
[
0
]
assert
np
.
allclose
(
model_outputs
.
shape
,
rewrite_outputs
.
shape
,
rtol
=
1e-03
,
atol
=
1e-03
)
def
get_centerpoint
():
from
mmdet3d.models.detectors.centerpoint
import
CenterPoint
model
=
CenterPoint
(
**
model_cfg
.
centerpoint_model
)
model
.
requires_grad_
(
False
)
return
model
def
get_centerpoint_head
():
from
mmdet3d.models
import
builder
model_cfg
.
centerpoint_model
.
pts_bbox_head
.
test_cfg
=
model_cfg
.
\
centerpoint_model
.
test_cfg
head
=
builder
.
build_head
(
model_cfg
.
centerpoint_model
.
pts_bbox_head
)
head
.
requires_grad_
(
False
)
return
head
@
pytest
.
mark
.
parametrize
(
'backend_type'
,
[
Backend
.
ONNXRUNTIME
])
def
test_pointpillars
(
backend_type
:
Backend
):
from
mmdeploy.core
import
RewriterContext
check_backend
(
backend_type
,
True
)
model_cfg
=
load_config
(
'tests/test_codebase/test_mmdet3d/data/model_cfg.py'
)[
0
]
deploy_cfg
=
mmengine
.
Config
(
dict
(
backend_config
=
dict
(
type
=
backend_type
.
value
),
onnx_config
=
dict
(
input_shape
=
None
,
opset_version
=
11
,
input_names
=
[
'voxels'
,
'num_points'
,
'coors'
],
output_names
=
[
'outputs'
]),
codebase_config
=
dict
(
type
=
Codebase
.
MMDET3D
.
value
,
task
=
Task
.
VOXEL_DETECTION
.
value
)))
task_processor
=
build_task_processor
(
model_cfg
,
deploy_cfg
,
'cpu'
)
model
=
task_processor
.
build_pytorch_model
(
None
)
model
.
eval
()
preproc
=
task_processor
.
build_data_preprocessor
()
_
,
data
=
task_processor
.
create_input
(
pcd
=
'tests/test_codebase/test_mmdet3d/data/kitti/kitti_000008.bin'
,
data_preprocessor
=
preproc
)
with
RewriterContext
(
cfg
=
deploy_cfg
,
backend
=
deploy_cfg
.
backend_config
.
type
,
opset
=
deploy_cfg
.
onnx_config
.
opset_version
):
outputs
=
model
.
forward
(
*
data
)
assert
len
(
outputs
)
==
3
def
get_pointpillars_nus
():
from
mmdet3d.models.detectors
import
MVXFasterRCNN
model
=
MVXFasterRCNN
(
**
model_cfg
.
pointpillars_nus_model
)
model
.
requires_grad_
(
False
)
return
model
@
pytest
.
mark
.
parametrize
(
'backend_type'
,
[
Backend
.
ONNXRUNTIME
])
def
test_centerpoint
(
backend_type
:
Backend
):
from
mmdeploy.core
import
RewriterContext
check_backend
(
backend_type
,
True
)
centerpoint_model_cfg
=
load_config
(
'tests/test_codebase/test_mmdet3d/data/centerpoint_pillar02_second_secfpn_head-circlenms_8xb4-cyclic-20e_nus-3d.py'
# noqa: E501
)[
0
]
deploy_cfg
=
mmengine
.
Config
(
dict
(
backend_config
=
dict
(
type
=
backend_type
.
value
),
onnx_config
=
dict
(
input_shape
=
None
,
opset_version
=
11
,
input_names
=
[
'voxels'
,
'num_points'
,
'coors'
],
output_names
=
[
'outputs'
]),
codebase_config
=
dict
(
type
=
Codebase
.
MMDET3D
.
value
,
task
=
Task
.
VOXEL_DETECTION
.
value
)))
task_processor
=
build_task_processor
(
centerpoint_model_cfg
,
deploy_cfg
,
'cpu'
)
model
=
task_processor
.
build_pytorch_model
(
None
)
model
.
eval
()
preproc
=
task_processor
.
build_data_preprocessor
()
_
,
data
=
task_processor
.
create_input
(
pcd
=
# noqa: E251
'tests/test_codebase/test_mmdet3d/data/nuscenes/n008-2018-09-18-12-07-26-0400__LIDAR_TOP__1537287083900561.pcd.bin'
,
# noqa: E501
data_preprocessor
=
preproc
)
with
RewriterContext
(
cfg
=
deploy_cfg
,
backend
=
deploy_cfg
.
backend_config
.
type
,
opset
=
deploy_cfg
.
onnx_config
.
opset_version
):
outputs
=
model
.
forward
(
data
)
assert
outputs
is
not
None
@
pytest
.
mark
.
parametrize
(
'backend_type'
,
[
Backend
.
ONNXRUNTIME
])
def
test_smoke
(
backend_type
:
Backend
):
from
mmdeploy.core
import
RewriterContext
check_backend
(
backend_type
,
True
)
model_cfg
=
load_config
(
'tests/test_codebase/test_mmdet3d/data/smoke_dla34_dlaneck_gn-all_4xb8-6x_kitti-mono3d.py'
# noqa: E501
)[
0
]
# noqa: E501
deploy_cfg
=
mmengine
.
Config
(
dict
(
backend_config
=
dict
(
type
=
backend_type
.
value
),
onnx_config
=
dict
(
input_shape
=
None
,
opset_version
=
11
,
input_names
=
[
'input'
],
output_names
=
[
'cls_score'
,
'bbox_pred'
]),
codebase_config
=
dict
(
type
=
Codebase
.
MMDET3D
.
value
,
task
=
Task
.
MONO_DETECTION
.
value
)))
task_processor
=
build_task_processor
(
model_cfg
,
deploy_cfg
,
'cpu'
)
model
=
task_processor
.
build_pytorch_model
(
None
)
model
.
eval
()
preproc
=
task_processor
.
build_data_preprocessor
()
_
,
data
=
task_processor
.
create_input
(
pcd
=
# noqa: E251
'tests/test_codebase/test_mmdet3d/data/nuscenes/n015-2018-07-24-11-22-45+0800.pkl'
,
# noqa: E501
data_preprocessor
=
preproc
)
with
RewriterContext
(
cfg
=
deploy_cfg
,
backend
=
deploy_cfg
.
backend_config
.
type
,
opset
=
deploy_cfg
.
onnx_config
.
opset_version
):
cls_score
,
bbox_pred
=
model
.
forward
(
data
)
assert
len
(
cls_score
)
==
1
and
len
(
bbox_pred
)
==
1
tests/test_codebase/test_mmdet3d/test_mono_detection.py
0 → 100755
View file @
e4fb2aa4
# Copyright (c) OpenMMLab. All rights reserved.
from
tempfile
import
NamedTemporaryFile
,
TemporaryDirectory
import
mmengine
import
pytest
import
torch
import
mmdeploy.backend.onnxruntime
as
ort_apis
from
mmdeploy.apis
import
build_task_processor
from
mmdeploy.codebase
import
import_codebase
from
mmdeploy.utils
import
Codebase
,
load_config
from
mmdeploy.utils.test
import
DummyModel
,
SwitchBackendWrapper
try
:
import_codebase
(
Codebase
.
MMDET3D
)
except
ImportError
:
pytest
.
skip
(
f
'
{
Codebase
.
MMDET3D
}
is not installed.'
,
allow_module_level
=
True
)
model_cfg_path
=
'tests/test_codebase/test_mmdet3d/data/smoke_dla34_dlaneck_gn-all_4xb8-6x_kitti-mono3d.py'
# noqa: E501
pcd_path
=
'tests/test_codebase/test_mmdet3d/data/nuscenes/n015-2018-07-24-11-22-45+0800.pkl'
# noqa: E501
model_cfg
=
load_config
(
model_cfg_path
)[
0
]
deploy_cfg
=
mmengine
.
Config
(
dict
(
backend_config
=
dict
(
type
=
'onnxruntime'
),
codebase_config
=
dict
(
type
=
'mmdet3d'
,
task
=
'MonoDetection'
),
onnx_config
=
dict
(
type
=
'onnx'
,
export_params
=
True
,
keep_initializers_as_inputs
=
False
,
opset_version
=
11
,
input_shape
=
None
,
input_names
=
[
'input'
],
output_names
=
[
'cls_score'
,
'bbox_pred'
])))
onnx_file
=
NamedTemporaryFile
(
suffix
=
'.onnx'
).
name
task_processor
=
None
@
pytest
.
fixture
(
autouse
=
True
)
def
init_task_processor
():
global
task_processor
task_processor
=
build_task_processor
(
model_cfg
,
deploy_cfg
,
'cpu'
)
def
test_build_pytorch_model
():
from
mmdet3d.models
import
SingleStageMono3DDetector
model
=
task_processor
.
build_pytorch_model
(
None
)
assert
isinstance
(
model
,
SingleStageMono3DDetector
)
@
pytest
.
fixture
def
backend_model
():
from
mmdeploy.backend.onnxruntime
import
ORTWrapper
ort_apis
.
__dict__
.
update
({
'ORTWrapper'
:
ORTWrapper
})
wrapper
=
SwitchBackendWrapper
(
ORTWrapper
)
wrapper
.
set
(
outputs
=
{
'cls_score'
:
torch
.
rand
(
1
,
3
,
96
,
320
),
'bbox_pred'
:
torch
.
rand
(
1
,
8
,
96
,
320
),
})
yield
task_processor
.
build_backend_model
([
''
])
wrapper
.
recover
()
def
test_build_backend_model
(
backend_model
):
from
mmdeploy.codebase.mmdet3d.deploy.mono_detection_model
import
\
MonoDetectionModel
assert
isinstance
(
backend_model
,
MonoDetectionModel
)
@
pytest
.
mark
.
parametrize
(
'device'
,
[
'cpu'
,
'cuda:0'
])
def
test_create_input
(
device
):
if
device
==
'cuda:0'
and
not
torch
.
cuda
.
is_available
():
pytest
.
skip
(
'cuda is not available'
)
original_device
=
task_processor
.
device
task_processor
.
device
=
device
inputs
=
task_processor
.
create_input
(
pcd_path
)
assert
len
(
inputs
)
==
2
task_processor
.
device
=
original_device
@
pytest
.
mark
.
skipif
(
reason
=
'Only support GPU test'
,
condition
=
not
torch
.
cuda
.
is_available
())
def
test_single_gpu_test_and_evaluate
():
task_processor
.
device
=
'cuda:0'
# Prepare dummy model
model
=
DummyModel
(
outputs
=
[
torch
.
rand
([
1
,
3
,
96
,
320
]),
torch
.
rand
([
1
,
8
,
96
,
320
])])
assert
model
is
not
None
# Run test
with
TemporaryDirectory
()
as
dir
:
task_processor
.
build_test_runner
(
model
,
dir
)
tests/test_codebase/test_mmdet3d/test_mono_detection_model.py
0 → 100755
View file @
e4fb2aa4
# Copyright (c) OpenMMLab. All rights reserved.
import
mmengine
import
pytest
import
torch
import
mmdeploy.backend.onnxruntime
as
ort_apis
from
mmdeploy.codebase
import
import_codebase
from
mmdeploy.utils
import
Backend
,
Codebase
from
mmdeploy.utils.test
import
SwitchBackendWrapper
,
backend_checker
try
:
import_codebase
(
Codebase
.
MMDET3D
)
except
ImportError
:
pytest
.
skip
(
f
'
{
Codebase
.
MMDET3D
}
is not installed.'
,
allow_module_level
=
True
)
from
mmdeploy.codebase.mmdet3d.deploy.mono_detection_model
import
\
MonoDetectionModel
nuscenes_pcd_path
=
'tests/test_codebase/test_mmdet3d/data/nuscenes/n015-2018-07-24-11-22-45+0800.pkl'
# noqa: E501
somke_model_cfg
=
'tests/test_codebase/test_mmdet3d/data/smoke_dla34_dlaneck_gn-all_4xb8-6x_kitti-mono3d.py'
# noqa: E501
@
backend_checker
(
Backend
.
ONNXRUNTIME
)
class
TestMonoDetectionModel
:
@
classmethod
def
setup_class
(
cls
):
# force add backend wrapper regardless of plugins
from
mmdeploy.backend.onnxruntime
import
ORTWrapper
ort_apis
.
__dict__
.
update
({
'ORTWrapper'
:
ORTWrapper
})
# simplify backend inference
cls
.
wrapper
=
SwitchBackendWrapper
(
ORTWrapper
)
cls
.
outputs
=
{
'cls_score'
:
torch
.
rand
(
1
,
3
,
96
,
320
),
'bbox_pred'
:
torch
.
rand
(
1
,
8
,
96
,
320
),
}
cls
.
wrapper
.
set
(
outputs
=
cls
.
outputs
)
deploy_cfg
=
mmengine
.
Config
({
'onnx_config'
:
{
'input_names'
:
[
'input'
],
'output_names'
:
[
'cls_score'
,
'bbox_pred'
],
'opset_version'
:
11
},
'backend_config'
:
{
'type'
:
'onnxruntime'
}
})
from
mmdeploy.utils
import
load_config
model_cfg_path
=
somke_model_cfg
model_cfg
=
load_config
(
model_cfg_path
)[
0
]
cls
.
end2end_model
=
MonoDetectionModel
(
Backend
.
ONNXRUNTIME
,
[
''
],
device
=
'cuda'
,
deploy_cfg
=
deploy_cfg
,
model_cfg
=
model_cfg
)
@
classmethod
def
teardown_class
(
cls
):
cls
.
wrapper
.
recover
()
@
pytest
.
mark
.
skipif
(
reason
=
'Only support GPU test'
,
condition
=
not
torch
.
cuda
.
is_available
())
def
test_forward_and_show_result
(
self
):
inputs
=
{
'imgs'
:
torch
.
rand
((
1
,
3
,
384
,
1280
)),
}
results
=
self
.
end2end_model
.
forward
(
inputs
=
inputs
)
assert
results
is
not
None
@
backend_checker
(
Backend
.
ONNXRUNTIME
)
def
test_build_mono_detection_model
():
from
mmdeploy.utils
import
load_config
model_cfg_path
=
somke_model_cfg
model_cfg
=
load_config
(
model_cfg_path
)[
0
]
deploy_cfg
=
mmengine
.
Config
(
dict
(
backend_config
=
dict
(
type
=
Backend
.
ONNXRUNTIME
.
value
),
onnx_config
=
dict
(
output_names
=
[
'cls_score'
,
'bbox_pred'
]),
codebase_config
=
dict
(
type
=
Codebase
.
MMDET3D
.
value
)))
from
mmdeploy.backend.onnxruntime
import
ORTWrapper
ort_apis
.
__dict__
.
update
({
'ORTWrapper'
:
ORTWrapper
})
# simplify backend inference
with
SwitchBackendWrapper
(
ORTWrapper
)
as
wrapper
:
wrapper
.
set
(
model_cfg
=
model_cfg
,
deploy_cfg
=
deploy_cfg
)
from
mmdeploy.codebase.mmdet3d.deploy.mono_detection_model
import
(
MonoDetectionModel
,
build_mono_detection_model
)
monodetector
=
build_mono_detection_model
([
''
],
model_cfg
=
model_cfg
,
deploy_cfg
=
deploy_cfg
,
device
=
'cpu'
)
assert
isinstance
(
monodetector
,
MonoDetectionModel
)
tests/test_codebase/test_mmdet3d/test_voxel_detection.py
0 → 100644
View file @
e4fb2aa4
# Copyright (c) OpenMMLab. All rights reserved.
from
tempfile
import
NamedTemporaryFile
,
TemporaryDirectory
import
mmengine
import
pytest
import
torch
import
mmdeploy.backend.onnxruntime
as
ort_apis
from
mmdeploy.apis
import
build_task_processor
from
mmdeploy.codebase
import
import_codebase
from
mmdeploy.utils
import
Codebase
,
load_config
from
mmdeploy.utils.test
import
DummyModel
,
SwitchBackendWrapper
try
:
import_codebase
(
Codebase
.
MMDET3D
)
except
ImportError
:
pytest
.
skip
(
f
'
{
Codebase
.
MMDET3D
}
is not installed.'
,
allow_module_level
=
True
)
model_cfg_path
=
'tests/test_codebase/test_mmdet3d/data/model_cfg.py'
pcd_path
=
'tests/test_codebase/test_mmdet3d/data/kitti/kitti_000008.bin'
model_cfg
=
load_config
(
model_cfg_path
)[
0
]
deploy_cfg
=
mmengine
.
Config
(
dict
(
backend_config
=
dict
(
type
=
'onnxruntime'
),
codebase_config
=
dict
(
type
=
'mmdet3d'
,
task
=
'VoxelDetection'
),
onnx_config
=
dict
(
type
=
'onnx'
,
export_params
=
True
,
keep_initializers_as_inputs
=
False
,
opset_version
=
11
,
input_shape
=
None
,
input_names
=
[
'voxels'
,
'num_points'
,
'coors'
],
output_names
=
[
'cls_score'
,
'bbox_pred'
,
'dir_cls_pred'
])))
onnx_file
=
NamedTemporaryFile
(
suffix
=
'.onnx'
).
name
task_processor
=
None
@
pytest
.
fixture
(
autouse
=
True
)
def
init_task_processor
():
global
task_processor
task_processor
=
build_task_processor
(
model_cfg
,
deploy_cfg
,
'cpu'
)
def
test_build_pytorch_model
():
from
mmdet3d.models
import
Base3DDetector
model
=
task_processor
.
build_pytorch_model
(
None
)
assert
isinstance
(
model
,
Base3DDetector
)
@
pytest
.
fixture
def
backend_model
():
from
mmdeploy.backend.onnxruntime
import
ORTWrapper
ort_apis
.
__dict__
.
update
({
'ORTWrapper'
:
ORTWrapper
})
wrapper
=
SwitchBackendWrapper
(
ORTWrapper
)
wrapper
.
set
(
outputs
=
{
'cls_score'
:
torch
.
rand
(
1
,
18
,
32
,
32
),
'bbox_pred'
:
torch
.
rand
(
1
,
42
,
32
,
32
),
'dir_cls_pred'
:
torch
.
rand
(
1
,
12
,
32
,
32
)
})
yield
task_processor
.
build_backend_model
([
''
])
wrapper
.
recover
()
def
test_build_backend_model
(
backend_model
):
from
mmdeploy.codebase.mmdet3d.deploy.voxel_detection_model
import
\
VoxelDetectionModel
assert
isinstance
(
backend_model
,
VoxelDetectionModel
)
@
pytest
.
mark
.
parametrize
(
'device'
,
[
'cpu'
,
'cuda:0'
])
def
test_create_input
(
device
):
if
device
==
'cuda:0'
and
not
torch
.
cuda
.
is_available
():
pytest
.
skip
(
'cuda is not available'
)
original_device
=
task_processor
.
device
task_processor
.
device
=
device
inputs
=
task_processor
.
create_input
(
pcd_path
)
assert
len
(
inputs
)
==
2
task_processor
.
device
=
original_device
@
pytest
.
mark
.
skipif
(
reason
=
'Only support GPU test'
,
condition
=
not
torch
.
cuda
.
is_available
())
def
test_single_gpu_test_and_evaluate
():
task_processor
.
device
=
'cuda:0'
# Prepare dummy model
model
=
DummyModel
(
outputs
=
[
torch
.
rand
([
1
,
10
,
5
]),
torch
.
rand
([
1
,
10
])])
assert
model
is
not
None
# Run test
with
TemporaryDirectory
()
as
dir
:
task_processor
.
build_test_runner
(
model
,
dir
)
tests/test_codebase/test_mmdet3d/test_voxel_detection_model.py
0 → 100644
View file @
e4fb2aa4
# Copyright (c) OpenMMLab. All rights reserved.
import
mmengine
import
pytest
import
torch
import
mmdeploy.backend.onnxruntime
as
ort_apis
from
mmdeploy.codebase
import
import_codebase
from
mmdeploy.utils
import
Backend
,
Codebase
from
mmdeploy.utils.test
import
SwitchBackendWrapper
,
backend_checker
try
:
import_codebase
(
Codebase
.
MMDET3D
)
except
ImportError
:
pytest
.
skip
(
f
'
{
Codebase
.
MMDET3D
}
is not installed.'
,
allow_module_level
=
True
)
from
mmdeploy.codebase.mmdet3d.deploy.voxel_detection_model
import
\
VoxelDetectionModel
pcd_path
=
'tests/test_codebase/test_mmdet3d/data/kitti/kitti_000008.bin'
model_cfg
=
'tests/test_codebase/test_mmdet3d/data/model_cfg.py'
@
backend_checker
(
Backend
.
ONNXRUNTIME
)
class
TestVoxelDetectionModel
:
@
classmethod
def
setup_class
(
cls
):
# force add backend wrapper regardless of plugins
from
mmdeploy.backend.onnxruntime
import
ORTWrapper
ort_apis
.
__dict__
.
update
({
'ORTWrapper'
:
ORTWrapper
})
# simplify backend inference
cls
.
wrapper
=
SwitchBackendWrapper
(
ORTWrapper
)
cls
.
outputs
=
{
'cls_score0'
:
torch
.
rand
(
1
,
18
,
32
,
32
),
'bbox_pred0'
:
torch
.
rand
(
1
,
42
,
32
,
32
),
'dir_cls_pred0'
:
torch
.
rand
(
1
,
12
,
32
,
32
)
}
cls
.
wrapper
.
set
(
outputs
=
cls
.
outputs
)
deploy_cfg
=
mmengine
.
Config
({
'onnx_config'
:
{
'input_names'
:
[
'voxels'
,
'num_points'
,
'coors'
],
'output_names'
:
[
'cls_score0'
,
'bbox_pred0'
,
'dir_cls_pred0'
],
'opset_version'
:
11
},
'backend_config'
:
{
'type'
:
'onnxruntime'
}
})
from
mmdeploy.utils
import
load_config
model_cfg_path
=
'tests/test_codebase/test_mmdet3d/data/model_cfg.py'
model_cfg
=
load_config
(
model_cfg_path
)[
0
]
cls
.
end2end_model
=
VoxelDetectionModel
(
Backend
.
ONNXRUNTIME
,
[
''
],
device
=
'cuda'
,
deploy_cfg
=
deploy_cfg
,
model_cfg
=
model_cfg
)
@
classmethod
def
teardown_class
(
cls
):
cls
.
wrapper
.
recover
()
@
pytest
.
mark
.
skipif
(
reason
=
'Only support GPU test'
,
condition
=
not
torch
.
cuda
.
is_available
())
def
test_forward_and_show_result
(
self
):
inputs
=
{
'voxels'
:
{
'voxels'
:
torch
.
rand
((
3945
,
32
,
4
)),
'num_points'
:
torch
.
ones
((
3945
),
dtype
=
torch
.
int32
),
'coors'
:
torch
.
ones
((
3945
,
4
),
dtype
=
torch
.
int32
)
}
}
results
=
self
.
end2end_model
.
forward
(
inputs
=
inputs
)
assert
results
is
not
None
@
backend_checker
(
Backend
.
ONNXRUNTIME
)
def
test_build_voxel_detection_model
():
from
mmdeploy.utils
import
load_config
model_cfg_path
=
'tests/test_codebase/test_mmdet3d/data/model_cfg.py'
model_cfg
=
load_config
(
model_cfg_path
)[
0
]
deploy_cfg
=
mmengine
.
Config
(
dict
(
backend_config
=
dict
(
type
=
Backend
.
ONNXRUNTIME
.
value
),
onnx_config
=
dict
(
output_names
=
[
'cls_score0'
,
'bbox_pred0'
,
'dir_cls_pred0'
]),
codebase_config
=
dict
(
type
=
Codebase
.
MMDET3D
.
value
)))
from
mmdeploy.backend.onnxruntime
import
ORTWrapper
ort_apis
.
__dict__
.
update
({
'ORTWrapper'
:
ORTWrapper
})
# simplify backend inference
with
SwitchBackendWrapper
(
ORTWrapper
)
as
wrapper
:
wrapper
.
set
(
model_cfg
=
model_cfg
,
deploy_cfg
=
deploy_cfg
)
from
mmdeploy.codebase.mmdet3d.deploy.voxel_detection_model
import
(
VoxelDetectionModel
,
build_voxel_detection_model
)
voxeldetector
=
build_voxel_detection_model
([
''
],
model_cfg
=
model_cfg
,
deploy_cfg
=
deploy_cfg
,
device
=
'cpu'
)
assert
isinstance
(
voxeldetector
,
VoxelDetectionModel
)
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment