Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
32a4328b
"vscode:/vscode.git/clone" did not exist on "150a18521282b4cf8408950c63dfdea57d7e76dd"
Unverified
Commit
32a4328b
authored
Feb 24, 2022
by
Wenwei Zhang
Committed by
GitHub
Feb 24, 2022
Browse files
Bump version to V1.0.0rc0
Bump version to V1.0.0rc0
parents
86cc487c
a8817998
Changes
414
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
652 additions
and
130 deletions
+652
-130
tests/test_utils/test_setup_env.py
tests/test_utils/test_setup_env.py
+68
-0
tests/test_utils/test_utils.py
tests/test_utils/test_utils.py
+277
-1
tools/analysis_tools/analyze_logs.py
tools/analysis_tools/analyze_logs.py
+2
-1
tools/analysis_tools/benchmark.py
tools/analysis_tools/benchmark.py
+1
-0
tools/analysis_tools/get_flops.py
tools/analysis_tools/get_flops.py
+1
-0
tools/create_data.py
tools/create_data.py
+6
-4
tools/create_data.sh
tools/create_data.sh
+5
-6
tools/data_converter/create_gt_database.py
tools/data_converter/create_gt_database.py
+10
-9
tools/data_converter/indoor_converter.py
tools/data_converter/indoor_converter.py
+7
-5
tools/data_converter/kitti_converter.py
tools/data_converter/kitti_converter.py
+43
-27
tools/data_converter/kitti_data_utils.py
tools/data_converter/kitti_data_utils.py
+22
-1
tools/data_converter/lyft_converter.py
tools/data_converter/lyft_converter.py
+14
-11
tools/data_converter/lyft_data_fixer.py
tools/data_converter/lyft_data_fixer.py
+2
-1
tools/data_converter/nuimage_converter.py
tools/data_converter/nuimage_converter.py
+2
-1
tools/data_converter/nuscenes_converter.py
tools/data_converter/nuscenes_converter.py
+22
-18
tools/data_converter/s3dis_data_utils.py
tools/data_converter/s3dis_data_utils.py
+15
-11
tools/data_converter/scannet_data_utils.py
tools/data_converter/scannet_data_utils.py
+15
-11
tools/data_converter/sunrgbd_data_utils.py
tools/data_converter/sunrgbd_data_utils.py
+22
-17
tools/data_converter/waymo_converter.py
tools/data_converter/waymo_converter.py
+7
-6
tools/deployment/mmdet3d2torchserve.py
tools/deployment/mmdet3d2torchserve.py
+111
-0
No files found.
tests/test_utils/test_setup_env.py
0 → 100644
View file @
32a4328b
# Copyright (c) OpenMMLab. All rights reserved.
import
multiprocessing
as
mp
import
os
import
platform
import
cv2
from
mmcv
import
Config
from
mmdet3d.utils
import
setup_multi_processes
def
test_setup_multi_processes
():
# temp save system setting
sys_start_mehod
=
mp
.
get_start_method
(
allow_none
=
True
)
sys_cv_threads
=
cv2
.
getNumThreads
()
# pop and temp save system env vars
sys_omp_threads
=
os
.
environ
.
pop
(
'OMP_NUM_THREADS'
,
default
=
None
)
sys_mkl_threads
=
os
.
environ
.
pop
(
'MKL_NUM_THREADS'
,
default
=
None
)
# test config without setting env
config
=
dict
(
data
=
dict
(
workers_per_gpu
=
2
))
cfg
=
Config
(
config
)
setup_multi_processes
(
cfg
)
assert
os
.
getenv
(
'OMP_NUM_THREADS'
)
==
'1'
assert
os
.
getenv
(
'MKL_NUM_THREADS'
)
==
'1'
# when set to 0, the num threads will be 1
assert
cv2
.
getNumThreads
()
==
1
if
platform
.
system
()
!=
'Windows'
:
assert
mp
.
get_start_method
()
==
'fork'
# test num workers <= 1
os
.
environ
.
pop
(
'OMP_NUM_THREADS'
)
os
.
environ
.
pop
(
'MKL_NUM_THREADS'
)
config
=
dict
(
data
=
dict
(
workers_per_gpu
=
0
))
cfg
=
Config
(
config
)
setup_multi_processes
(
cfg
)
assert
'OMP_NUM_THREADS'
not
in
os
.
environ
assert
'MKL_NUM_THREADS'
not
in
os
.
environ
# test manually set env var
os
.
environ
[
'OMP_NUM_THREADS'
]
=
'4'
config
=
dict
(
data
=
dict
(
workers_per_gpu
=
2
))
cfg
=
Config
(
config
)
setup_multi_processes
(
cfg
)
assert
os
.
getenv
(
'OMP_NUM_THREADS'
)
==
'4'
# test manually set opencv threads and mp start method
config
=
dict
(
data
=
dict
(
workers_per_gpu
=
2
),
opencv_num_threads
=
4
,
mp_start_method
=
'spawn'
)
cfg
=
Config
(
config
)
setup_multi_processes
(
cfg
)
assert
cv2
.
getNumThreads
()
==
4
assert
mp
.
get_start_method
()
==
'spawn'
# revert setting to avoid affecting other programs
if
sys_start_mehod
:
mp
.
set_start_method
(
sys_start_mehod
,
force
=
True
)
cv2
.
setNumThreads
(
sys_cv_threads
)
if
sys_omp_threads
:
os
.
environ
[
'OMP_NUM_THREADS'
]
=
sys_omp_threads
else
:
os
.
environ
.
pop
(
'OMP_NUM_THREADS'
)
if
sys_mkl_threads
:
os
.
environ
[
'MKL_NUM_THREADS'
]
=
sys_mkl_threads
else
:
os
.
environ
.
pop
(
'MKL_NUM_THREADS'
)
tests/test_utils/test_utils.py
View file @
32a4328b
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
import
numpy
as
np
import
pytest
import
torch
import
torch
from
mmdet3d.core
import
draw_heatmap_gaussian
from
mmdet3d.core
import
array_converter
,
draw_heatmap_gaussian
,
points_img2cam
from
mmdet3d.core.bbox
import
CameraInstance3DBoxes
from
mmdet3d.models.utils
import
(
filter_outside_objs
,
get_edge_indices
,
get_keypoints
,
handle_proj_objs
)
def
test_gaussian
():
def
test_gaussian
():
...
@@ -10,3 +15,274 @@ def test_gaussian():
...
@@ -10,3 +15,274 @@ def test_gaussian():
radius
=
2
radius
=
2
draw_heatmap_gaussian
(
heatmap
,
ct_int
,
radius
)
draw_heatmap_gaussian
(
heatmap
,
ct_int
,
radius
)
assert
torch
.
isclose
(
torch
.
sum
(
heatmap
),
torch
.
tensor
(
4.3505
),
atol
=
1e-3
)
assert
torch
.
isclose
(
torch
.
sum
(
heatmap
),
torch
.
tensor
(
4.3505
),
atol
=
1e-3
)
def
test_array_converter
():
# to torch
@
array_converter
(
to_torch
=
True
,
apply_to
=
(
'array_a'
,
'array_b'
))
def
test_func_1
(
array_a
,
array_b
,
container
):
container
.
append
(
array_a
)
container
.
append
(
array_b
)
return
array_a
.
clone
(),
array_b
.
clone
()
np_array_a
=
np
.
array
([
0.0
])
np_array_b
=
np
.
array
([
0.0
])
container
=
[]
new_array_a
,
new_array_b
=
test_func_1
(
np_array_a
,
np_array_b
,
container
)
assert
isinstance
(
new_array_a
,
np
.
ndarray
)
assert
isinstance
(
new_array_b
,
np
.
ndarray
)
assert
isinstance
(
container
[
0
],
torch
.
Tensor
)
assert
isinstance
(
container
[
1
],
torch
.
Tensor
)
# one to torch and one not
@
array_converter
(
to_torch
=
True
,
apply_to
=
(
'array_a'
,
))
def
test_func_2
(
array_a
,
array_b
):
return
torch
.
cat
([
array_a
,
array_b
])
with
pytest
.
raises
(
TypeError
):
_
=
test_func_2
(
np_array_a
,
np_array_b
)
# wrong template_arg_name_
@
array_converter
(
to_torch
=
True
,
apply_to
=
(
'array_a'
,
),
template_arg_name_
=
'array_c'
)
def
test_func_3
(
array_a
,
array_b
):
return
torch
.
cat
([
array_a
,
array_b
])
with
pytest
.
raises
(
ValueError
):
_
=
test_func_3
(
np_array_a
,
np_array_b
)
# wrong apply_to
@
array_converter
(
to_torch
=
True
,
apply_to
=
(
'array_a'
,
'array_c'
))
def
test_func_4
(
array_a
,
array_b
):
return
torch
.
cat
([
array_a
,
array_b
])
with
pytest
.
raises
(
ValueError
):
_
=
test_func_4
(
np_array_a
,
np_array_b
)
# to numpy
@
array_converter
(
to_torch
=
False
,
apply_to
=
(
'array_a'
,
'array_b'
))
def
test_func_5
(
array_a
,
array_b
,
container
):
container
.
append
(
array_a
)
container
.
append
(
array_b
)
return
array_a
.
copy
(),
array_b
.
copy
()
pt_array_a
=
torch
.
tensor
([
0.0
])
pt_array_b
=
torch
.
tensor
([
0.0
])
container
=
[]
new_array_a
,
new_array_b
=
test_func_5
(
pt_array_a
,
pt_array_b
,
container
)
assert
isinstance
(
container
[
0
],
np
.
ndarray
)
assert
isinstance
(
container
[
1
],
np
.
ndarray
)
assert
isinstance
(
new_array_a
,
torch
.
Tensor
)
assert
isinstance
(
new_array_b
,
torch
.
Tensor
)
# apply_to = None
@
array_converter
(
to_torch
=
False
)
def
test_func_6
(
array_a
,
array_b
,
container
):
container
.
append
(
array_a
)
container
.
append
(
array_b
)
return
array_a
.
clone
(),
array_b
.
clone
()
container
=
[]
new_array_a
,
new_array_b
=
test_func_6
(
pt_array_a
,
pt_array_b
,
container
)
assert
isinstance
(
container
[
0
],
torch
.
Tensor
)
assert
isinstance
(
container
[
1
],
torch
.
Tensor
)
assert
isinstance
(
new_array_a
,
torch
.
Tensor
)
assert
isinstance
(
new_array_b
,
torch
.
Tensor
)
# with default arg
@
array_converter
(
to_torch
=
True
,
apply_to
=
(
'array_a'
,
'array_b'
))
def
test_func_7
(
array_a
,
container
,
array_b
=
np
.
array
([
2.
])):
container
.
append
(
array_a
)
container
.
append
(
array_b
)
return
array_a
.
clone
(),
array_b
.
clone
()
container
=
[]
new_array_a
,
new_array_b
=
test_func_7
(
np_array_a
,
container
)
assert
isinstance
(
container
[
0
],
torch
.
Tensor
)
assert
isinstance
(
container
[
1
],
torch
.
Tensor
)
assert
isinstance
(
new_array_a
,
np
.
ndarray
)
assert
isinstance
(
new_array_b
,
np
.
ndarray
)
assert
np
.
allclose
(
new_array_b
,
np
.
array
([
2.
]),
1e-3
)
# override default arg
container
=
[]
new_array_a
,
new_array_b
=
test_func_7
(
np_array_a
,
container
,
np
.
array
([
4.
]))
assert
isinstance
(
container
[
0
],
torch
.
Tensor
)
assert
isinstance
(
container
[
1
],
torch
.
Tensor
)
assert
isinstance
(
new_array_a
,
np
.
ndarray
)
assert
np
.
allclose
(
new_array_b
,
np
.
array
([
4.
]),
1e-3
)
# list arg
@
array_converter
(
to_torch
=
True
,
apply_to
=
(
'array_a'
,
'array_b'
))
def
test_func_8
(
container
,
array_a
,
array_b
=
[
2.
]):
container
.
append
(
array_a
)
container
.
append
(
array_b
)
return
array_a
.
clone
(),
array_b
.
clone
()
container
=
[]
new_array_a
,
new_array_b
=
test_func_8
(
container
,
[
3.
])
assert
isinstance
(
container
[
0
],
torch
.
Tensor
)
assert
isinstance
(
container
[
1
],
torch
.
Tensor
)
assert
np
.
allclose
(
new_array_a
,
np
.
array
([
3.
]),
1e-3
)
assert
np
.
allclose
(
new_array_b
,
np
.
array
([
2.
]),
1e-3
)
# number arg
@
array_converter
(
to_torch
=
True
,
apply_to
=
(
'array_a'
,
'array_b'
))
def
test_func_9
(
container
,
array_a
,
array_b
=
1
):
container
.
append
(
array_a
)
container
.
append
(
array_b
)
return
array_a
.
clone
(),
array_b
.
clone
()
container
=
[]
new_array_a
,
new_array_b
=
test_func_9
(
container
,
np_array_a
)
assert
isinstance
(
container
[
0
],
torch
.
FloatTensor
)
assert
isinstance
(
container
[
1
],
torch
.
FloatTensor
)
assert
np
.
allclose
(
new_array_a
,
np_array_a
,
1e-3
)
assert
np
.
allclose
(
new_array_b
,
np
.
array
(
1.0
),
1e-3
)
# feed kwargs
container
=
[]
kwargs
=
{
'array_a'
:
[
5.
],
'array_b'
:
[
6.
]}
new_array_a
,
new_array_b
=
test_func_8
(
container
,
**
kwargs
)
assert
isinstance
(
container
[
0
],
torch
.
Tensor
)
assert
isinstance
(
container
[
1
],
torch
.
Tensor
)
assert
np
.
allclose
(
new_array_a
,
np
.
array
([
5.
]),
1e-3
)
assert
np
.
allclose
(
new_array_b
,
np
.
array
([
6.
]),
1e-3
)
# feed args and kwargs
container
=
[]
kwargs
=
{
'array_b'
:
[
7.
]}
args
=
(
container
,
[
8.
])
new_array_a
,
new_array_b
=
test_func_8
(
*
args
,
**
kwargs
)
assert
isinstance
(
container
[
0
],
torch
.
Tensor
)
assert
isinstance
(
container
[
1
],
torch
.
Tensor
)
assert
np
.
allclose
(
new_array_a
,
np
.
array
([
8.
]),
1e-3
)
assert
np
.
allclose
(
new_array_b
,
np
.
array
([
7.
]),
1e-3
)
# wrong template arg type
with
pytest
.
raises
(
TypeError
):
new_array_a
,
new_array_b
=
test_func_9
(
container
,
3
+
4j
)
with
pytest
.
raises
(
TypeError
):
new_array_a
,
new_array_b
=
test_func_9
(
container
,
{})
# invalid template arg list
with
pytest
.
raises
(
TypeError
):
new_array_a
,
new_array_b
=
test_func_9
(
container
,
[
True
,
np
.
array
([
3.0
])])
def
test_points_img2cam
():
points
=
torch
.
tensor
([[
0.5764
,
0.9109
,
0.7576
],
[
0.6656
,
0.5498
,
0.9813
]])
cam2img
=
torch
.
tensor
([[
700.
,
0.
,
450.
,
0.
],
[
0.
,
700.
,
200.
,
0.
],
[
0.
,
0.
,
1.
,
0.
]])
xyzs
=
points_img2cam
(
points
,
cam2img
)
expected_xyzs
=
torch
.
tensor
([[
-
0.4864
,
-
0.2155
,
0.7576
],
[
-
0.6299
,
-
0.2796
,
0.9813
]])
assert
torch
.
allclose
(
xyzs
,
expected_xyzs
,
atol
=
1e-3
)
def
test_generate_edge_indices
():
input_metas
=
[
dict
(
img_shape
=
(
110
,
110
),
pad_shape
=
(
128
,
128
)),
dict
(
img_shape
=
(
98
,
110
),
pad_shape
=
(
128
,
128
))
]
downsample_ratio
=
4
edge_indices_list
=
get_edge_indices
(
input_metas
,
downsample_ratio
)
assert
edge_indices_list
[
0
].
shape
[
0
]
==
108
assert
edge_indices_list
[
1
].
shape
[
0
]
==
102
def
test_truncation_hanlde
():
centers2d_list
=
[
torch
.
tensor
([[
-
99.86
,
199.45
],
[
499.50
,
399.20
],
[
201.20
,
99.86
]])
]
gt_bboxes_list
=
[
torch
.
tensor
([[
0.25
,
99.8
,
99.8
,
199.6
],
[
300.2
,
250.1
,
399.8
,
299.6
],
[
100.2
,
20.1
,
300.8
,
180.7
]])
]
img_metas
=
[
dict
(
img_shape
=
[
300
,
400
])]
centers2d_target_list
,
offsets2d_list
,
trunc_mask_list
=
\
handle_proj_objs
(
centers2d_list
,
gt_bboxes_list
,
img_metas
)
centers2d_target
=
torch
.
tensor
([[
0.
,
166.30435501
],
[
379.03437877
,
299.
],
[
201.2
,
99.86
]])
offsets2d
=
torch
.
tensor
([[
-
99.86
,
33.45
],
[
120.5
,
100.2
],
[
0.2
,
-
0.14
]])
trunc_mask
=
torch
.
tensor
([
True
,
True
,
False
])
assert
torch
.
allclose
(
centers2d_target_list
[
0
],
centers2d_target
)
assert
torch
.
allclose
(
offsets2d_list
[
0
],
offsets2d
,
atol
=
1e-4
)
assert
torch
.
all
(
trunc_mask_list
[
0
]
==
trunc_mask
)
assert
torch
.
allclose
(
centers2d_target_list
[
0
].
round
().
int
()
+
offsets2d_list
[
0
],
centers2d_list
[
0
])
def
test_filter_outside_objs
():
centers2d_list
=
[
torch
.
tensor
([[
-
99.86
,
199.45
],
[
499.50
,
399.20
],
[
201.20
,
99.86
]]),
torch
.
tensor
([[
-
47.86
,
199.45
],
[
410.50
,
399.20
],
[
401.20
,
349.86
]])
]
gt_bboxes_list
=
[
torch
.
rand
([
3
,
4
],
dtype
=
torch
.
float32
),
torch
.
rand
([
3
,
4
],
dtype
=
torch
.
float32
)
]
gt_bboxes_3d_list
=
[
CameraInstance3DBoxes
(
torch
.
rand
([
3
,
7
]),
box_dim
=
7
),
CameraInstance3DBoxes
(
torch
.
rand
([
3
,
7
]),
box_dim
=
7
)
]
gt_labels_list
=
[
torch
.
tensor
([
0
,
1
,
2
]),
torch
.
tensor
([
2
,
0
,
0
])]
gt_labels_3d_list
=
[
torch
.
tensor
([
0
,
1
,
2
]),
torch
.
tensor
([
2
,
0
,
0
])]
img_metas
=
[
dict
(
img_shape
=
[
300
,
400
]),
dict
(
img_shape
=
[
500
,
450
])]
filter_outside_objs
(
gt_bboxes_list
,
gt_labels_list
,
gt_bboxes_3d_list
,
gt_labels_3d_list
,
centers2d_list
,
img_metas
)
assert
len
(
centers2d_list
[
0
])
==
len
(
gt_bboxes_3d_list
[
0
])
==
\
len
(
gt_bboxes_list
[
0
])
==
len
(
gt_labels_3d_list
[
0
])
==
\
len
(
gt_labels_list
[
0
])
==
1
assert
len
(
centers2d_list
[
1
])
==
len
(
gt_bboxes_3d_list
[
1
])
==
\
len
(
gt_bboxes_list
[
1
])
==
len
(
gt_labels_3d_list
[
1
])
==
\
len
(
gt_labels_list
[
1
])
==
2
def
test_generate_keypoints
():
centers2d_list
=
[
torch
.
tensor
([[
-
99.86
,
199.45
],
[
499.50
,
399.20
],
[
201.20
,
99.86
]]),
torch
.
tensor
([[
-
47.86
,
199.45
],
[
410.50
,
399.20
],
[
401.20
,
349.86
]])
]
gt_bboxes_3d_list
=
[
CameraInstance3DBoxes
(
torch
.
rand
([
3
,
7
])),
CameraInstance3DBoxes
(
torch
.
rand
([
3
,
7
]))
]
img_metas
=
[
dict
(
cam2img
=
[[
1260.8474446004698
,
0.0
,
807.968244525554
,
40.1111
],
[
0.0
,
1260.8474446004698
,
495.3344268742088
,
2.34422
],
[
0.0
,
0.0
,
1.0
,
0.00333333
],
[
0.0
,
0.0
,
0.0
,
1.0
]],
img_shape
=
(
300
,
400
))
for
i
in
range
(
2
)
]
keypoints2d_list
,
keypoints_depth_mask_list
=
\
get_keypoints
(
gt_bboxes_3d_list
,
centers2d_list
,
img_metas
)
assert
keypoints2d_list
[
0
].
shape
==
(
3
,
10
,
3
)
assert
keypoints_depth_mask_list
[
0
].
shape
==
(
3
,
3
)
tools/analysis_tools/analyze_logs.py
View file @
32a4328b
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
import
argparse
import
argparse
import
json
import
json
from
collections
import
defaultdict
import
numpy
as
np
import
numpy
as
np
import
seaborn
as
sns
import
seaborn
as
sns
from
collections
import
defaultdict
from
matplotlib
import
pyplot
as
plt
from
matplotlib
import
pyplot
as
plt
...
...
tools/analysis_tools/benchmark.py
View file @
32a4328b
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
import
argparse
import
argparse
import
time
import
time
import
torch
import
torch
from
mmcv
import
Config
from
mmcv
import
Config
from
mmcv.parallel
import
MMDataParallel
from
mmcv.parallel
import
MMDataParallel
...
...
tools/analysis_tools/get_flops.py
View file @
32a4328b
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
import
argparse
import
argparse
import
torch
import
torch
from
mmcv
import
Config
,
DictAction
from
mmcv
import
Config
,
DictAction
...
...
tools/create_data.py
View file @
32a4328b
...
@@ -61,7 +61,8 @@ def nuscenes_data_prep(root_path,
...
@@ -61,7 +61,8 @@ def nuscenes_data_prep(root_path,
version (str): Dataset version.
version (str): Dataset version.
dataset_name (str): The dataset class name.
dataset_name (str): The dataset class name.
out_dir (str): Output directory of the groundtruth database info.
out_dir (str): Output directory of the groundtruth database info.
max_sweeps (int): Number of input consecutive frames. Default: 10
max_sweeps (int, optional): Number of input consecutive frames.
Default: 10
"""
"""
nuscenes_converter
.
create_nuscenes_infos
(
nuscenes_converter
.
create_nuscenes_infos
(
root_path
,
info_prefix
,
version
=
version
,
max_sweeps
=
max_sweeps
)
root_path
,
info_prefix
,
version
=
version
,
max_sweeps
=
max_sweeps
)
...
@@ -152,8 +153,9 @@ def waymo_data_prep(root_path,
...
@@ -152,8 +153,9 @@ def waymo_data_prep(root_path,
info_prefix (str): The prefix of info filenames.
info_prefix (str): The prefix of info filenames.
out_dir (str): Output directory of the generated info file.
out_dir (str): Output directory of the generated info file.
workers (int): Number of threads to be used.
workers (int): Number of threads to be used.
max_sweeps (int): Number of input consecutive frames. Default: 5
\
max_sweeps (int, optional): Number of input consecutive frames.
Here we store pose information of these frames for later use.
Default: 5. Here we store pose information of these frames
for later use.
"""
"""
from
tools.data_converter
import
waymo_converter
as
waymo
from
tools.data_converter
import
waymo_converter
as
waymo
...
@@ -206,7 +208,7 @@ parser.add_argument(
...
@@ -206,7 +208,7 @@ parser.add_argument(
'--out-dir'
,
'--out-dir'
,
type
=
str
,
type
=
str
,
default
=
'./data/kitti'
,
default
=
'./data/kitti'
,
required
=
'
False
'
,
required
=
False
,
help
=
'name of info pkl'
)
help
=
'name of info pkl'
)
parser
.
add_argument
(
'--extra-tag'
,
type
=
str
,
default
=
'kitti'
)
parser
.
add_argument
(
'--extra-tag'
,
type
=
str
,
default
=
'kitti'
)
parser
.
add_argument
(
parser
.
add_argument
(
...
...
tools/create_data.sh
View file @
32a4328b
...
@@ -5,8 +5,7 @@ export PYTHONPATH=`pwd`:$PYTHONPATH
...
@@ -5,8 +5,7 @@ export PYTHONPATH=`pwd`:$PYTHONPATH
PARTITION
=
$1
PARTITION
=
$1
JOB_NAME
=
$2
JOB_NAME
=
$2
CONFIG
=
$3
DATASET
=
$3
WORK_DIR
=
$4
GPUS
=
${
GPUS
:-
1
}
GPUS
=
${
GPUS
:-
1
}
GPUS_PER_NODE
=
${
GPUS_PER_NODE
:-
1
}
GPUS_PER_NODE
=
${
GPUS_PER_NODE
:-
1
}
SRUN_ARGS
=
${
SRUN_ARGS
:-
""
}
SRUN_ARGS
=
${
SRUN_ARGS
:-
""
}
...
@@ -19,7 +18,7 @@ srun -p ${PARTITION} \
...
@@ -19,7 +18,7 @@ srun -p ${PARTITION} \
--ntasks-per-node
=
${
GPUS_PER_NODE
}
\
--ntasks-per-node
=
${
GPUS_PER_NODE
}
\
--kill-on-bad-exit
=
1
\
--kill-on-bad-exit
=
1
\
${
SRUN_ARGS
}
\
${
SRUN_ARGS
}
\
python
-u
tools/create_data.py
kitti
\
python
-u
tools/create_data.py
${
DATASET
}
\
--root-path
./data/
kitti
\
--root-path
./data/
${
DATASET
}
\
--out-dir
./data/
kitti
\
--out-dir
./data/
${
DATASET
}
\
--extra-tag
kitti
--extra-tag
${
DATASET
}
tools/data_converter/create_gt_database.py
View file @
32a4328b
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
import
pickle
from
os
import
path
as
osp
import
mmcv
import
mmcv
import
numpy
as
np
import
numpy
as
np
import
pickle
from
mmcv
import
track_iter_progress
from
mmcv
import
track_iter_progress
from
mmcv.ops
import
roi_align
from
mmcv.ops
import
roi_align
from
os
import
path
as
osp
from
pycocotools
import
mask
as
maskUtils
from
pycocotools
import
mask
as
maskUtils
from
pycocotools.coco
import
COCO
from
pycocotools.coco
import
COCO
...
@@ -126,19 +127,19 @@ def create_groundtruth_database(dataset_class_name,
...
@@ -126,19 +127,19 @@ def create_groundtruth_database(dataset_class_name,
dataset_class_name (str): Name of the input dataset.
dataset_class_name (str): Name of the input dataset.
data_path (str): Path of the data.
data_path (str): Path of the data.
info_prefix (str): Prefix of the info file.
info_prefix (str): Prefix of the info file.
info_path (str): Path of the info file.
info_path (str
, optional
): Path of the info file.
Default: None.
Default: None.
mask_anno_path (str): Path of the mask_anno.
mask_anno_path (str
, optional
): Path of the mask_anno.
Default: None.
Default: None.
used_classes (list[str]): Classes have been used.
used_classes (list[str]
, optional
): Classes have been used.
Default: None.
Default: None.
database_save_path (str): Path to save database.
database_save_path (str
, optional
): Path to save database.
Default: None.
Default: None.
db_info_save_path (str): Path to save db_info.
db_info_save_path (str
, optional
): Path to save db_info.
Default: None.
Default: None.
relative_path (bool): Whether to use relative path.
relative_path (bool
, optional
): Whether to use relative path.
Default: True.
Default: True.
with_mask (bool): Whether to use mask.
with_mask (bool
, optional
): Whether to use mask.
Default: False.
Default: False.
"""
"""
print
(
f
'Create GT Database of
{
dataset_class_name
}
'
)
print
(
f
'Create GT Database of
{
dataset_class_name
}
'
)
...
...
tools/data_converter/indoor_converter.py
View file @
32a4328b
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
import
os
import
mmcv
import
mmcv
import
numpy
as
np
import
numpy
as
np
import
os
from
tools.data_converter.s3dis_data_utils
import
S3DISData
,
S3DISSegData
from
tools.data_converter.s3dis_data_utils
import
S3DISData
,
S3DISSegData
from
tools.data_converter.scannet_data_utils
import
ScanNetData
,
ScanNetSegData
from
tools.data_converter.scannet_data_utils
import
ScanNetData
,
ScanNetSegData
...
@@ -19,10 +20,11 @@ def create_indoor_info_file(data_path,
...
@@ -19,10 +20,11 @@ def create_indoor_info_file(data_path,
Args:
Args:
data_path (str): Path of the data.
data_path (str): Path of the data.
pkl_prefix (str): Prefix of the pkl to be saved. Default: 'sunrgbd'.
pkl_prefix (str, optional): Prefix of the pkl to be saved.
save_path (str): Path of the pkl to be saved. Default: None.
Default: 'sunrgbd'.
use_v1 (bool): Whether to use v1. Default: False.
save_path (str, optional): Path of the pkl to be saved. Default: None.
workers (int): Number of threads to be used. Default: 4.
use_v1 (bool, optional): Whether to use v1. Default: False.
workers (int, optional): Number of threads to be used. Default: 4.
"""
"""
assert
os
.
path
.
exists
(
data_path
)
assert
os
.
path
.
exists
(
data_path
)
assert
pkl_prefix
in
[
'sunrgbd'
,
'scannet'
,
's3dis'
],
\
assert
pkl_prefix
in
[
'sunrgbd'
,
'scannet'
,
's3dis'
],
\
...
...
tools/data_converter/kitti_converter.py
View file @
32a4328b
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
from
collections
import
OrderedDict
from
pathlib
import
Path
import
mmcv
import
mmcv
import
numpy
as
np
import
numpy
as
np
from
collections
import
OrderedDict
from
nuscenes.utils.geometry_utils
import
view_points
from
nuscenes.utils.geometry_utils
import
view_points
from
pathlib
import
Path
from
mmdet3d.core.bbox
import
box_np_ops
from
mmdet3d.core.bbox
import
box_np_ops
,
points_cam2img
from
.kitti_data_utils
import
get_kitti_image_info
,
get_waymo_image_info
from
.kitti_data_utils
import
get_kitti_image_info
,
get_waymo_image_info
from
.nuscenes_converter
import
post_process_coords
from
.nuscenes_converter
import
post_process_coords
...
@@ -94,9 +95,12 @@ def create_kitti_info_file(data_path,
...
@@ -94,9 +95,12 @@ def create_kitti_info_file(data_path,
Args:
Args:
data_path (str): Path of the data root.
data_path (str): Path of the data root.
pkl_prefix (str): Prefix of the info file to be generated.
pkl_prefix (str, optional): Prefix of the info file to be generated.
save_path (str): Path to save the info file.
Default: 'kitti'.
relative_path (bool): Whether to use relative path.
save_path (str, optional): Path to save the info file.
Default: None.
relative_path (bool, optional): Whether to use relative path.
Default: True.
"""
"""
imageset_folder
=
Path
(
data_path
)
/
'ImageSets'
imageset_folder
=
Path
(
data_path
)
/
'ImageSets'
train_img_ids
=
_read_imageset_file
(
str
(
imageset_folder
/
'train.txt'
))
train_img_ids
=
_read_imageset_file
(
str
(
imageset_folder
/
'train.txt'
))
...
@@ -113,6 +117,7 @@ def create_kitti_info_file(data_path,
...
@@ -113,6 +117,7 @@ def create_kitti_info_file(data_path,
training
=
True
,
training
=
True
,
velodyne
=
True
,
velodyne
=
True
,
calib
=
True
,
calib
=
True
,
with_plane
=
True
,
image_ids
=
train_img_ids
,
image_ids
=
train_img_ids
,
relative_path
=
relative_path
)
relative_path
=
relative_path
)
_calculate_num_points_in_gt
(
data_path
,
kitti_infos_train
,
relative_path
)
_calculate_num_points_in_gt
(
data_path
,
kitti_infos_train
,
relative_path
)
...
@@ -124,6 +129,7 @@ def create_kitti_info_file(data_path,
...
@@ -124,6 +129,7 @@ def create_kitti_info_file(data_path,
training
=
True
,
training
=
True
,
velodyne
=
True
,
velodyne
=
True
,
calib
=
True
,
calib
=
True
,
with_plane
=
True
,
image_ids
=
val_img_ids
,
image_ids
=
val_img_ids
,
relative_path
=
relative_path
)
relative_path
=
relative_path
)
_calculate_num_points_in_gt
(
data_path
,
kitti_infos_val
,
relative_path
)
_calculate_num_points_in_gt
(
data_path
,
kitti_infos_val
,
relative_path
)
...
@@ -140,6 +146,7 @@ def create_kitti_info_file(data_path,
...
@@ -140,6 +146,7 @@ def create_kitti_info_file(data_path,
label_info
=
False
,
label_info
=
False
,
velodyne
=
True
,
velodyne
=
True
,
calib
=
True
,
calib
=
True
,
with_plane
=
False
,
image_ids
=
test_img_ids
,
image_ids
=
test_img_ids
,
relative_path
=
relative_path
)
relative_path
=
relative_path
)
filename
=
save_path
/
f
'
{
pkl_prefix
}
_infos_test.pkl'
filename
=
save_path
/
f
'
{
pkl_prefix
}
_infos_test.pkl'
...
@@ -158,10 +165,14 @@ def create_waymo_info_file(data_path,
...
@@ -158,10 +165,14 @@ def create_waymo_info_file(data_path,
Args:
Args:
data_path (str): Path of the data root.
data_path (str): Path of the data root.
pkl_prefix (str): Prefix of the info file to be generated.
pkl_prefix (str, optional): Prefix of the info file to be generated.
save_path (str | None): Path to save the info file.
Default: 'waymo'.
relative_path (bool): Whether to use relative path.
save_path (str, optional): Path to save the info file.
max_sweeps (int): Max sweeps before the detection frame to be used.
Default: None.
relative_path (bool, optional): Whether to use relative path.
Default: True.
max_sweeps (int, optional): Max sweeps before the detection frame
to be used. Default: 5.
"""
"""
imageset_folder
=
Path
(
data_path
)
/
'ImageSets'
imageset_folder
=
Path
(
data_path
)
/
'ImageSets'
train_img_ids
=
_read_imageset_file
(
str
(
imageset_folder
/
'train.txt'
))
train_img_ids
=
_read_imageset_file
(
str
(
imageset_folder
/
'train.txt'
))
...
@@ -238,11 +249,13 @@ def _create_reduced_point_cloud(data_path,
...
@@ -238,11 +249,13 @@ def _create_reduced_point_cloud(data_path,
Args:
Args:
data_path (str): Path of original data.
data_path (str): Path of original data.
info_path (str): Path of data info.
info_path (str): Path of data info.
save_path (str | None): Path to save reduced point cloud data.
save_path (str, optional): Path to save reduced point cloud
Default: None.
data. Default: None.
back (bool): Whether to flip the points to back.
back (bool, optional): Whether to flip the points to back.
num_features (int): Number of point features. Default: 4.
Default: False.
front_camera_id (int): The referenced/front camera ID. Default: 2.
num_features (int, optional): Number of point features. Default: 4.
front_camera_id (int, optional): The referenced/front camera ID.
Default: 2.
"""
"""
kitti_infos
=
mmcv
.
load
(
info_path
)
kitti_infos
=
mmcv
.
load
(
info_path
)
...
@@ -298,14 +311,16 @@ def create_reduced_point_cloud(data_path,
...
@@ -298,14 +311,16 @@ def create_reduced_point_cloud(data_path,
Args:
Args:
data_path (str): Path of original data.
data_path (str): Path of original data.
pkl_prefix (str): Prefix of info files.
pkl_prefix (str): Prefix of info files.
train_info_path (str | None): Path of training set info.
train_info_path (str, optional): Path of training set info.
Default: None.
val_info_path (str, optional): Path of validation set info.
Default: None.
Default: None.
val
_info_path (str
| None
): Path of
validation
set info.
test
_info_path (str
, optional
): Path of
test
set info.
Default: None.
Default: None.
test_info
_path (str
| None): Path of test set info
.
save
_path (str
, optional): Path to save reduced point cloud data
.
Default: None.
Default: None.
save_path (str | None): Path to save reduced point cloud data
.
with_back (bool, optional): Whether to flip the points to back
.
with_back (bool): Whether to flip the points to back
.
Default: False
.
"""
"""
if
train_info_path
is
None
:
if
train_info_path
is
None
:
train_info_path
=
Path
(
data_path
)
/
f
'
{
pkl_prefix
}
_infos_train.pkl'
train_info_path
=
Path
(
data_path
)
/
f
'
{
pkl_prefix
}
_infos_train.pkl'
...
@@ -335,7 +350,8 @@ def export_2d_annotation(root_path, info_path, mono3d=True):
...
@@ -335,7 +350,8 @@ def export_2d_annotation(root_path, info_path, mono3d=True):
Args:
Args:
root_path (str): Root path of the raw data.
root_path (str): Root path of the raw data.
info_path (str): Path of the info file.
info_path (str): Path of the info file.
mono3d (bool): Whether to export mono3d annotation. Default: True.
mono3d (bool, optional): Whether to export mono3d annotation.
Default: True.
"""
"""
# get bbox annotations for camera
# get bbox annotations for camera
kitti_infos
=
mmcv
.
load
(
info_path
)
kitti_infos
=
mmcv
.
load
(
info_path
)
...
@@ -381,8 +397,8 @@ def get_2d_boxes(info, occluded, mono3d=True):
...
@@ -381,8 +397,8 @@ def get_2d_boxes(info, occluded, mono3d=True):
Args:
Args:
info: Information of the given sample data.
info: Information of the given sample data.
occluded: Integer (0, 1, 2, 3) indicating occlusion state:
\
occluded: Integer (0, 1, 2, 3) indicating occlusion state:
0 = fully visible, 1 = partly occluded, 2 = largely occluded,
\
0 = fully visible, 1 = partly occluded, 2 = largely occluded,
3 = unknown, -1 = DontCare
3 = unknown, -1 = DontCare
mono3d (bool): Whether to get boxes with mono3d annotation.
mono3d (bool): Whether to get boxes with mono3d annotation.
...
@@ -471,7 +487,7 @@ def get_2d_boxes(info, occluded, mono3d=True):
...
@@ -471,7 +487,7 @@ def get_2d_boxes(info, occluded, mono3d=True):
repro_rec
[
'velo_cam3d'
]
=
-
1
# no velocity in KITTI
repro_rec
[
'velo_cam3d'
]
=
-
1
# no velocity in KITTI
center3d
=
np
.
array
(
loc
).
reshape
([
1
,
3
])
center3d
=
np
.
array
(
loc
).
reshape
([
1
,
3
])
center2d
=
box_np_ops
.
points_cam2img
(
center2d
=
points_cam2img
(
center3d
,
camera_intrinsic
,
with_depth
=
True
)
center3d
,
camera_intrinsic
,
with_depth
=
True
)
repro_rec
[
'center2d'
]
=
center2d
.
squeeze
().
tolist
()
repro_rec
[
'center2d'
]
=
center2d
.
squeeze
().
tolist
()
# normalized center2D + depth
# normalized center2D + depth
...
@@ -488,7 +504,7 @@ def get_2d_boxes(info, occluded, mono3d=True):
...
@@ -488,7 +504,7 @@ def get_2d_boxes(info, occluded, mono3d=True):
def
generate_record
(
ann_rec
,
x1
,
y1
,
x2
,
y2
,
sample_data_token
,
filename
):
def
generate_record
(
ann_rec
,
x1
,
y1
,
x2
,
y2
,
sample_data_token
,
filename
):
"""Generate one 2D annotation record given various information
s
on top of
"""Generate one 2D annotation record given various information on top of
the 2D bounding box coordinates.
the 2D bounding box coordinates.
Args:
Args:
...
@@ -503,12 +519,12 @@ def generate_record(ann_rec, x1, y1, x2, y2, sample_data_token, filename):
...
@@ -503,12 +519,12 @@ def generate_record(ann_rec, x1, y1, x2, y2, sample_data_token, filename):
Returns:
Returns:
dict: A sample 2D annotation record.
dict: A sample 2D annotation record.
- file_name (str): f
l
ie name
- file_name (str): fi
l
e name
- image_id (str): sample data token
- image_id (str): sample data token
- area (float): 2d box area
- area (float): 2d box area
- category_name (str): category name
- category_name (str): category name
- category_id (int): category id
- category_id (int): category id
- bbox (list[float]): left x, top y,
dx, dy
of 2d box
- bbox (list[float]): left x, top y,
x_size, y_size
of 2d box
- iscrowd (int): whether the area is crowd
- iscrowd (int): whether the area is crowd
"""
"""
repro_rec
=
OrderedDict
()
repro_rec
=
OrderedDict
()
...
...
tools/data_converter/kitti_data_utils.py
View file @
32a4328b
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
import
numpy
as
np
from
collections
import
OrderedDict
from
collections
import
OrderedDict
from
concurrent
import
futures
as
futures
from
concurrent
import
futures
as
futures
from
os
import
path
as
osp
from
os
import
path
as
osp
from
pathlib
import
Path
from
pathlib
import
Path
import
mmcv
import
numpy
as
np
from
skimage
import
io
from
skimage
import
io
...
@@ -59,6 +61,17 @@ def get_label_path(idx,
...
@@ -59,6 +61,17 @@ def get_label_path(idx,
relative_path
,
exist_check
,
use_prefix_id
)
relative_path
,
exist_check
,
use_prefix_id
)
def
get_plane_path
(
idx
,
prefix
,
training
=
True
,
relative_path
=
True
,
exist_check
=
True
,
info_type
=
'planes'
,
use_prefix_id
=
False
):
return
get_kitti_info_path
(
idx
,
prefix
,
info_type
,
'.txt'
,
training
,
relative_path
,
exist_check
,
use_prefix_id
)
def
get_velodyne_path
(
idx
,
def
get_velodyne_path
(
idx
,
prefix
,
prefix
,
training
=
True
,
training
=
True
,
...
@@ -143,6 +156,7 @@ def get_kitti_image_info(path,
...
@@ -143,6 +156,7 @@ def get_kitti_image_info(path,
label_info
=
True
,
label_info
=
True
,
velodyne
=
False
,
velodyne
=
False
,
calib
=
False
,
calib
=
False
,
with_plane
=
False
,
image_ids
=
7481
,
image_ids
=
7481
,
extend_matrix
=
True
,
extend_matrix
=
True
,
num_worker
=
8
,
num_worker
=
8
,
...
@@ -251,6 +265,13 @@ def get_kitti_image_info(path,
...
@@ -251,6 +265,13 @@ def get_kitti_image_info(path,
calib_info
[
'Tr_imu_to_velo'
]
=
Tr_imu_to_velo
calib_info
[
'Tr_imu_to_velo'
]
=
Tr_imu_to_velo
info
[
'calib'
]
=
calib_info
info
[
'calib'
]
=
calib_info
if
with_plane
:
plane_path
=
get_plane_path
(
idx
,
path
,
training
,
relative_path
)
if
relative_path
:
plane_path
=
str
(
root_path
/
plane_path
)
lines
=
mmcv
.
list_from_file
(
plane_path
)
info
[
'plane'
]
=
np
.
array
([
float
(
i
)
for
i
in
lines
[
3
].
split
()])
if
annotations
is
not
None
:
if
annotations
is
not
None
:
info
[
'annos'
]
=
annotations
info
[
'annos'
]
=
annotations
add_difficulty_to_annos
(
info
)
add_difficulty_to_annos
(
info
)
...
...
tools/data_converter/lyft_converter.py
View file @
32a4328b
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
import
mmcv
import
numpy
as
np
import
os
import
os
from
logging
import
warning
from
logging
import
warning
from
lyft_dataset_sdk.lyftdataset
import
LyftDataset
as
Lyft
from
os
import
path
as
osp
from
os
import
path
as
osp
import
mmcv
import
numpy
as
np
from
lyft_dataset_sdk.lyftdataset
import
LyftDataset
as
Lyft
from
pyquaternion
import
Quaternion
from
pyquaternion
import
Quaternion
from
mmdet3d.datasets
import
LyftDataset
from
mmdet3d.datasets
import
LyftDataset
...
@@ -26,10 +27,10 @@ def create_lyft_infos(root_path,
...
@@ -26,10 +27,10 @@ def create_lyft_infos(root_path,
Args:
Args:
root_path (str): Path of the data root.
root_path (str): Path of the data root.
info_prefix (str): Prefix of the info file to be generated.
info_prefix (str): Prefix of the info file to be generated.
version (str): Version of the data.
version (str
, optional
): Version of the data.
Default: 'v1.01-train'
Default: 'v1.01-train'
.
max_sweeps (int): Max number of sweeps.
max_sweeps (int
, optional
): Max number of sweeps.
Default: 10
Default: 10
.
"""
"""
lyft
=
Lyft
(
lyft
=
Lyft
(
data_path
=
osp
.
join
(
root_path
,
version
),
data_path
=
osp
.
join
(
root_path
,
version
),
...
@@ -101,9 +102,9 @@ def _fill_trainval_infos(lyft,
...
@@ -101,9 +102,9 @@ def _fill_trainval_infos(lyft,
lyft (:obj:`LyftDataset`): Dataset class in the Lyft dataset.
lyft (:obj:`LyftDataset`): Dataset class in the Lyft dataset.
train_scenes (list[str]): Basic information of training scenes.
train_scenes (list[str]): Basic information of training scenes.
val_scenes (list[str]): Basic information of validation scenes.
val_scenes (list[str]): Basic information of validation scenes.
test (bool): Whether use the test mode. In the test mode, no
test (bool
, optional
): Whether use the test mode. In the test mode, no
annotations can be accessed. Default: False.
annotations can be accessed. Default: False.
max_sweeps (int): Max number of sweeps. Default: 10.
max_sweeps (int
, optional
): Max number of sweeps. Default: 10.
Returns:
Returns:
tuple[list[dict]]: Information of training set and
tuple[list[dict]]: Information of training set and
...
@@ -192,8 +193,10 @@ def _fill_trainval_infos(lyft,
...
@@ -192,8 +193,10 @@ def _fill_trainval_infos(lyft,
names
[
i
]
=
LyftDataset
.
NameMapping
[
names
[
i
]]
names
[
i
]
=
LyftDataset
.
NameMapping
[
names
[
i
]]
names
=
np
.
array
(
names
)
names
=
np
.
array
(
names
)
# we need to convert rot to SECOND format.
# we need to convert box size to
gt_boxes
=
np
.
concatenate
([
locs
,
dims
,
-
rots
-
np
.
pi
/
2
],
axis
=
1
)
# the format of our lidar coordinate system
# which is x_size, y_size, z_size (corresponding to l, w, h)
gt_boxes
=
np
.
concatenate
([
locs
,
dims
[:,
[
1
,
0
,
2
]],
rots
],
axis
=
1
)
assert
len
(
gt_boxes
)
==
len
(
assert
len
(
gt_boxes
)
==
len
(
annotations
),
f
'
{
len
(
gt_boxes
)
}
,
{
len
(
annotations
)
}
'
annotations
),
f
'
{
len
(
gt_boxes
)
}
,
{
len
(
annotations
)
}
'
info
[
'gt_boxes'
]
=
gt_boxes
info
[
'gt_boxes'
]
=
gt_boxes
...
...
tools/data_converter/lyft_data_fixer.py
View file @
32a4328b
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
import
argparse
import
argparse
import
numpy
as
np
import
os
import
os
import
numpy
as
np
def
fix_lyft
(
root_folder
=
'./data/lyft'
,
version
=
'v1.01'
):
def
fix_lyft
(
root_folder
=
'./data/lyft'
,
version
=
'v1.01'
):
# refer to https://www.kaggle.com/c/3d-object-detection-for-autonomous-vehicles/discussion/110000 # noqa
# refer to https://www.kaggle.com/c/3d-object-detection-for-autonomous-vehicles/discussion/110000 # noqa
...
...
tools/data_converter/nuimage_converter.py
View file @
32a4328b
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
import
argparse
import
argparse
import
base64
import
base64
from
os
import
path
as
osp
import
mmcv
import
mmcv
import
numpy
as
np
import
numpy
as
np
from
nuimages
import
NuImages
from
nuimages
import
NuImages
from
nuimages.utils.utils
import
mask_decode
,
name_to_index_mapping
from
nuimages.utils.utils
import
mask_decode
,
name_to_index_mapping
from
os
import
path
as
osp
nus_categories
=
(
'car'
,
'truck'
,
'trailer'
,
'bus'
,
'construction_vehicle'
,
nus_categories
=
(
'car'
,
'truck'
,
'trailer'
,
'bus'
,
'construction_vehicle'
,
'bicycle'
,
'motorcycle'
,
'pedestrian'
,
'traffic_cone'
,
'bicycle'
,
'motorcycle'
,
'pedestrian'
,
'traffic_cone'
,
...
...
tools/data_converter/nuscenes_converter.py
View file @
32a4328b
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
import
mmcv
import
numpy
as
np
import
os
import
os
from
collections
import
OrderedDict
from
collections
import
OrderedDict
from
os
import
path
as
osp
from
typing
import
List
,
Tuple
,
Union
import
mmcv
import
numpy
as
np
from
nuscenes.nuscenes
import
NuScenes
from
nuscenes.nuscenes
import
NuScenes
from
nuscenes.utils.geometry_utils
import
view_points
from
nuscenes.utils.geometry_utils
import
view_points
from
os
import
path
as
osp
from
pyquaternion
import
Quaternion
from
pyquaternion
import
Quaternion
from
shapely.geometry
import
MultiPoint
,
box
from
shapely.geometry
import
MultiPoint
,
box
from
typing
import
List
,
Tuple
,
Union
from
mmdet3d.core.bbox
.box_np_ops
import
points_cam2img
from
mmdet3d.core.bbox
import
points_cam2img
from
mmdet3d.datasets
import
NuScenesDataset
from
mmdet3d.datasets
import
NuScenesDataset
nus_categories
=
(
'car'
,
'truck'
,
'trailer'
,
'bus'
,
'construction_vehicle'
,
nus_categories
=
(
'car'
,
'truck'
,
'trailer'
,
'bus'
,
'construction_vehicle'
,
...
@@ -34,10 +35,10 @@ def create_nuscenes_infos(root_path,
...
@@ -34,10 +35,10 @@ def create_nuscenes_infos(root_path,
Args:
Args:
root_path (str): Path of the data root.
root_path (str): Path of the data root.
info_prefix (str): Prefix of the info file to be generated.
info_prefix (str): Prefix of the info file to be generated.
version (str): Version of the data.
version (str
, optional
): Version of the data.
Default: 'v1.0-trainval'
Default: 'v1.0-trainval'
.
max_sweeps (int): Max number of sweeps.
max_sweeps (int
, optional
): Max number of sweeps.
Default: 10
Default: 10
.
"""
"""
from
nuscenes.nuscenes
import
NuScenes
from
nuscenes.nuscenes
import
NuScenes
nusc
=
NuScenes
(
version
=
version
,
dataroot
=
root_path
,
verbose
=
True
)
nusc
=
NuScenes
(
version
=
version
,
dataroot
=
root_path
,
verbose
=
True
)
...
@@ -152,9 +153,9 @@ def _fill_trainval_infos(nusc,
...
@@ -152,9 +153,9 @@ def _fill_trainval_infos(nusc,
nusc (:obj:`NuScenes`): Dataset class in the nuScenes dataset.
nusc (:obj:`NuScenes`): Dataset class in the nuScenes dataset.
train_scenes (list[str]): Basic information of training scenes.
train_scenes (list[str]): Basic information of training scenes.
val_scenes (list[str]): Basic information of validation scenes.
val_scenes (list[str]): Basic information of validation scenes.
test (bool): Whether use the test mode. In
the
test mode, no
test (bool
, optional
): Whether use the test mode. In test mode, no
annotations can be accessed. Default: False.
annotations can be accessed. Default: False.
max_sweeps (int): Max number of sweeps. Default: 10.
max_sweeps (int
, optional
): Max number of sweeps. Default: 10.
Returns:
Returns:
tuple[list[dict]]: Information of training set and validation set
tuple[list[dict]]: Information of training set and validation set
...
@@ -249,8 +250,10 @@ def _fill_trainval_infos(nusc,
...
@@ -249,8 +250,10 @@ def _fill_trainval_infos(nusc,
if
names
[
i
]
in
NuScenesDataset
.
NameMapping
:
if
names
[
i
]
in
NuScenesDataset
.
NameMapping
:
names
[
i
]
=
NuScenesDataset
.
NameMapping
[
names
[
i
]]
names
[
i
]
=
NuScenesDataset
.
NameMapping
[
names
[
i
]]
names
=
np
.
array
(
names
)
names
=
np
.
array
(
names
)
# we need to convert rot to SECOND format.
# we need to convert box size to
gt_boxes
=
np
.
concatenate
([
locs
,
dims
,
-
rots
-
np
.
pi
/
2
],
axis
=
1
)
# the format of our lidar coordinate system
# which is x_size, y_size, z_size (corresponding to l, w, h)
gt_boxes
=
np
.
concatenate
([
locs
,
dims
[:,
[
1
,
0
,
2
]],
rots
],
axis
=
1
)
assert
len
(
gt_boxes
)
==
len
(
assert
len
(
gt_boxes
)
==
len
(
annotations
),
f
'
{
len
(
gt_boxes
)
}
,
{
len
(
annotations
)
}
'
annotations
),
f
'
{
len
(
gt_boxes
)
}
,
{
len
(
annotations
)
}
'
info
[
'gt_boxes'
]
=
gt_boxes
info
[
'gt_boxes'
]
=
gt_boxes
...
@@ -289,7 +292,7 @@ def obtain_sensor2top(nusc,
...
@@ -289,7 +292,7 @@ def obtain_sensor2top(nusc,
e2g_t (np.ndarray): Translation from ego to global in shape (1, 3).
e2g_t (np.ndarray): Translation from ego to global in shape (1, 3).
e2g_r_mat (np.ndarray): Rotation matrix from ego to global
e2g_r_mat (np.ndarray): Rotation matrix from ego to global
in shape (3, 3).
in shape (3, 3).
sensor_type (str): Sensor to calibrate. Default: 'lidar'.
sensor_type (str
, optional
): Sensor to calibrate. Default: 'lidar'.
Returns:
Returns:
sweep (dict): Sweep information after transformation.
sweep (dict): Sweep information after transformation.
...
@@ -338,7 +341,8 @@ def export_2d_annotation(root_path, info_path, version, mono3d=True):
...
@@ -338,7 +341,8 @@ def export_2d_annotation(root_path, info_path, version, mono3d=True):
root_path (str): Root path of the raw data.
root_path (str): Root path of the raw data.
info_path (str): Path of the info file.
info_path (str): Path of the info file.
version (str): Dataset version.
version (str): Dataset version.
mono3d (bool): Whether to export mono3d annotation. Default: True.
mono3d (bool, optional): Whether to export mono3d annotation.
Default: True.
"""
"""
# get bbox annotations for camera
# get bbox annotations for camera
camera_types
=
[
camera_types
=
[
...
@@ -402,7 +406,7 @@ def get_2d_boxes(nusc,
...
@@ -402,7 +406,7 @@ def get_2d_boxes(nusc,
"""Get the 2D annotation records for a given `sample_data_token`.
"""Get the 2D annotation records for a given `sample_data_token`.
Args:
Args:
sample_data_token (str): Sample data token belonging to a camera
\
sample_data_token (str): Sample data token belonging to a camera
keyframe.
keyframe.
visibilities (list[str]): Visibility filter.
visibilities (list[str]): Visibility filter.
mono3d (bool): Whether to get boxes with mono3d annotation.
mono3d (bool): Whether to get boxes with mono3d annotation.
...
@@ -562,7 +566,7 @@ def post_process_coords(
...
@@ -562,7 +566,7 @@ def post_process_coords(
def
generate_record
(
ann_rec
:
dict
,
x1
:
float
,
y1
:
float
,
x2
:
float
,
y2
:
float
,
def
generate_record
(
ann_rec
:
dict
,
x1
:
float
,
y1
:
float
,
x2
:
float
,
y2
:
float
,
sample_data_token
:
str
,
filename
:
str
)
->
OrderedDict
:
sample_data_token
:
str
,
filename
:
str
)
->
OrderedDict
:
"""Generate one 2D annotation record given various information
s
on top of
"""Generate one 2D annotation record given various information on top of
the 2D bounding box coordinates.
the 2D bounding box coordinates.
Args:
Args:
...
@@ -577,7 +581,7 @@ def generate_record(ann_rec: dict, x1: float, y1: float, x2: float, y2: float,
...
@@ -577,7 +581,7 @@ def generate_record(ann_rec: dict, x1: float, y1: float, x2: float, y2: float,
Returns:
Returns:
dict: A sample 2D annotation record.
dict: A sample 2D annotation record.
- file_name (str): f
l
ie name
- file_name (str): fi
l
e name
- image_id (str): sample data token
- image_id (str): sample data token
- area (float): 2d box area
- area (float): 2d box area
- category_name (str): category name
- category_name (str): category name
...
...
tools/data_converter/s3dis_data_utils.py
View file @
32a4328b
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
import
mmcv
import
numpy
as
np
import
os
import
os
from
concurrent
import
futures
as
futures
from
concurrent
import
futures
as
futures
from
os
import
path
as
osp
from
os
import
path
as
osp
import
mmcv
import
numpy
as
np
class
S3DISData
(
object
):
class
S3DISData
(
object
):
"""S3DIS data.
"""S3DIS data.
...
@@ -13,7 +14,7 @@ class S3DISData(object):
...
@@ -13,7 +14,7 @@ class S3DISData(object):
Args:
Args:
root_path (str): Root path of the raw data.
root_path (str): Root path of the raw data.
split (str): Set split type of the data. Default: 'Area_1'.
split (str
, optional
): Set split type of the data. Default: 'Area_1'.
"""
"""
def
__init__
(
self
,
root_path
,
split
=
'Area_1'
):
def
__init__
(
self
,
root_path
,
split
=
'Area_1'
):
...
@@ -48,9 +49,11 @@ class S3DISData(object):
...
@@ -48,9 +49,11 @@ class S3DISData(object):
This method gets information from the raw data.
This method gets information from the raw data.
Args:
Args:
num_workers (int): Number of threads to be used. Default: 4.
num_workers (int, optional): Number of threads to be used.
has_label (bool): Whether the data has label. Default: True.
Default: 4.
sample_id_list (list[int]): Index list of the sample.
has_label (bool, optional): Whether the data has label.
Default: True.
sample_id_list (list[int], optional): Index list of the sample.
Default: None.
Default: None.
Returns:
Returns:
...
@@ -154,10 +157,11 @@ class S3DISSegData(object):
...
@@ -154,10 +157,11 @@ class S3DISSegData(object):
Args:
Args:
data_root (str): Root path of the raw data.
data_root (str): Root path of the raw data.
ann_file (str): The generated scannet infos.
ann_file (str): The generated scannet infos.
split (str): Set split type of the data. Default: 'train'.
split (str, optional): Set split type of the data. Default: 'train'.
num_points (int): Number of points in each data input. Default: 8192.
num_points (int, optional): Number of points in each data input.
label_weight_func (function): Function to compute the label weight.
Default: 8192.
Default: None.
label_weight_func (function, optional): Function to compute the
label weight. Default: None.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -209,7 +213,7 @@ class S3DISSegData(object):
...
@@ -209,7 +213,7 @@ class S3DISSegData(object):
return
label
return
label
def
get_scene_idxs_and_label_weight
(
self
):
def
get_scene_idxs_and_label_weight
(
self
):
"""Compute scene_idxs for data sampling and label weight for loss
\
"""Compute scene_idxs for data sampling and label weight for loss
calculation.
calculation.
We sample more times for scenes with more points. Label_weight is
We sample more times for scenes with more points. Label_weight is
...
...
tools/data_converter/scannet_data_utils.py
View file @
32a4328b
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
import
mmcv
import
numpy
as
np
import
os
import
os
from
concurrent
import
futures
as
futures
from
concurrent
import
futures
as
futures
from
os
import
path
as
osp
from
os
import
path
as
osp
import
mmcv
import
numpy
as
np
class
ScanNetData
(
object
):
class
ScanNetData
(
object
):
"""ScanNet data.
"""ScanNet data.
...
@@ -13,7 +14,7 @@ class ScanNetData(object):
...
@@ -13,7 +14,7 @@ class ScanNetData(object):
Args:
Args:
root_path (str): Root path of the raw data.
root_path (str): Root path of the raw data.
split (str): Set split type of the data. Default: 'train'.
split (str
, optional
): Set split type of the data. Default: 'train'.
"""
"""
def
__init__
(
self
,
root_path
,
split
=
'train'
):
def
__init__
(
self
,
root_path
,
split
=
'train'
):
...
@@ -90,9 +91,11 @@ class ScanNetData(object):
...
@@ -90,9 +91,11 @@ class ScanNetData(object):
This method gets information from the raw data.
This method gets information from the raw data.
Args:
Args:
num_workers (int): Number of threads to be used. Default: 4.
num_workers (int, optional): Number of threads to be used.
has_label (bool): Whether the data has label. Default: True.
Default: 4.
sample_id_list (list[int]): Index list of the sample.
has_label (bool, optional): Whether the data has label.
Default: True.
sample_id_list (list[int], optional): Index list of the sample.
Default: None.
Default: None.
Returns:
Returns:
...
@@ -201,10 +204,11 @@ class ScanNetSegData(object):
...
@@ -201,10 +204,11 @@ class ScanNetSegData(object):
Args:
Args:
data_root (str): Root path of the raw data.
data_root (str): Root path of the raw data.
ann_file (str): The generated scannet infos.
ann_file (str): The generated scannet infos.
split (str): Set split type of the data. Default: 'train'.
split (str, optional): Set split type of the data. Default: 'train'.
num_points (int): Number of points in each data input. Default: 8192.
num_points (int, optional): Number of points in each data input.
label_weight_func (function): Function to compute the label weight.
Default: 8192.
Default: None.
label_weight_func (function, optional): Function to compute the
label weight. Default: None.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -261,7 +265,7 @@ class ScanNetSegData(object):
...
@@ -261,7 +265,7 @@ class ScanNetSegData(object):
return
label
return
label
def
get_scene_idxs_and_label_weight
(
self
):
def
get_scene_idxs_and_label_weight
(
self
):
"""Compute scene_idxs for data sampling and label weight for loss
\
"""Compute scene_idxs for data sampling and label weight for loss
calculation.
calculation.
We sample more times for scenes with more points. Label_weight is
We sample more times for scenes with more points. Label_weight is
...
...
tools/data_converter/sunrgbd_data_utils.py
View file @
32a4328b
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
import
mmcv
import
numpy
as
np
from
concurrent
import
futures
as
futures
from
concurrent
import
futures
as
futures
from
os
import
path
as
osp
from
os
import
path
as
osp
import
mmcv
import
numpy
as
np
from
scipy
import
io
as
sio
from
scipy
import
io
as
sio
...
@@ -42,18 +43,20 @@ class SUNRGBDInstance(object):
...
@@ -42,18 +43,20 @@ class SUNRGBDInstance(object):
self
.
ymax
=
data
[
2
]
+
data
[
4
]
self
.
ymax
=
data
[
2
]
+
data
[
4
]
self
.
box2d
=
np
.
array
([
self
.
xmin
,
self
.
ymin
,
self
.
xmax
,
self
.
ymax
])
self
.
box2d
=
np
.
array
([
self
.
xmin
,
self
.
ymin
,
self
.
xmax
,
self
.
ymax
])
self
.
centroid
=
np
.
array
([
data
[
5
],
data
[
6
],
data
[
7
]])
self
.
centroid
=
np
.
array
([
data
[
5
],
data
[
6
],
data
[
7
]])
self
.
w
=
data
[
8
]
self
.
width
=
data
[
8
]
self
.
l
=
data
[
9
]
# noqa: E741
self
.
length
=
data
[
9
]
self
.
h
=
data
[
10
]
self
.
height
=
data
[
10
]
# data[9] is x_size (length), data[8] is y_size (width), data[10] is
# z_size (height) in our depth coordinate system,
# l corresponds to the size along the x axis
self
.
size
=
np
.
array
([
data
[
9
],
data
[
8
],
data
[
10
]])
*
2
self
.
orientation
=
np
.
zeros
((
3
,
))
self
.
orientation
=
np
.
zeros
((
3
,
))
self
.
orientation
[
0
]
=
data
[
11
]
self
.
orientation
[
0
]
=
data
[
11
]
self
.
orientation
[
1
]
=
data
[
12
]
self
.
orientation
[
1
]
=
data
[
12
]
self
.
heading_angle
=
-
1
*
np
.
arctan2
(
self
.
orientation
[
1
],
self
.
heading_angle
=
np
.
arctan2
(
self
.
orientation
[
1
],
self
.
orientation
[
0
])
self
.
orientation
[
0
])
self
.
box3d
=
np
.
concatenate
([
self
.
box3d
=
np
.
concatenate
(
self
.
centroid
,
[
self
.
centroid
,
self
.
size
,
self
.
heading_angle
[
None
]])
np
.
array
([
self
.
l
*
2
,
self
.
w
*
2
,
self
.
h
*
2
,
self
.
heading_angle
])
])
class
SUNRGBDData
(
object
):
class
SUNRGBDData
(
object
):
...
@@ -63,8 +66,8 @@ class SUNRGBDData(object):
...
@@ -63,8 +66,8 @@ class SUNRGBDData(object):
Args:
Args:
root_path (str): Root path of the raw data.
root_path (str): Root path of the raw data.
split (str): Set split type of the data. Default: 'train'.
split (str
, optional
): Set split type of the data. Default: 'train'.
use_v1 (bool): Whether to use v1. Default: False.
use_v1 (bool
, optional
): Whether to use v1. Default: False.
"""
"""
def
__init__
(
self
,
root_path
,
split
=
'train'
,
use_v1
=
False
):
def
__init__
(
self
,
root_path
,
split
=
'train'
,
use_v1
=
False
):
...
@@ -129,9 +132,11 @@ class SUNRGBDData(object):
...
@@ -129,9 +132,11 @@ class SUNRGBDData(object):
This method gets information from the raw data.
This method gets information from the raw data.
Args:
Args:
num_workers (int): Number of threads to be used. Default: 4.
num_workers (int, optional): Number of threads to be used.
has_label (bool): Whether the data has label. Default: True.
Default: 4.
sample_id_list (list[int]): Index list of the sample.
has_label (bool, optional): Whether the data has label.
Default: True.
sample_id_list (list[int], optional): Index list of the sample.
Default: None.
Default: None.
Returns:
Returns:
...
@@ -192,7 +197,7 @@ class SUNRGBDData(object):
...
@@ -192,7 +197,7 @@ class SUNRGBDData(object):
],
],
axis
=
0
)
axis
=
0
)
annotations
[
'dimensions'
]
=
2
*
np
.
array
([
annotations
[
'dimensions'
]
=
2
*
np
.
array
([
[
obj
.
l
,
obj
.
w
,
obj
.
h
]
for
obj
in
obj_list
[
obj
.
l
ength
,
obj
.
w
idth
,
obj
.
h
eight
]
for
obj
in
obj_list
if
obj
.
classname
in
self
.
cat2label
.
keys
()
if
obj
.
classname
in
self
.
cat2label
.
keys
()
])
# lwh (depth) format
])
# lwh (depth) format
annotations
[
'rotation_y'
]
=
np
.
array
([
annotations
[
'rotation_y'
]
=
np
.
array
([
...
...
tools/data_converter/waymo_converter.py
View file @
32a4328b
...
@@ -10,11 +10,12 @@ except ImportError:
...
@@ -10,11 +10,12 @@ except ImportError:
'Please run "pip install waymo-open-dataset-tf-2-1-0==1.2.0" '
'Please run "pip install waymo-open-dataset-tf-2-1-0==1.2.0" '
'to install the official devkit first.'
)
'to install the official devkit first.'
)
from
glob
import
glob
from
os.path
import
join
import
mmcv
import
mmcv
import
numpy
as
np
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow
as
tf
from
glob
import
glob
from
os.path
import
join
from
waymo_open_dataset.utils
import
range_image_utils
,
transform_utils
from
waymo_open_dataset.utils
import
range_image_utils
,
transform_utils
from
waymo_open_dataset.utils.frame_utils
import
\
from
waymo_open_dataset.utils.frame_utils
import
\
parse_range_image_and_camera_projection
parse_range_image_and_camera_projection
...
@@ -31,8 +32,8 @@ class Waymo2KITTI(object):
...
@@ -31,8 +32,8 @@ class Waymo2KITTI(object):
save_dir (str): Directory to save data in KITTI format.
save_dir (str): Directory to save data in KITTI format.
prefix (str): Prefix of filename. In general, 0 for training, 1 for
prefix (str): Prefix of filename. In general, 0 for training, 1 for
validation and 2 for testing.
validation and 2 for testing.
workers (
str
): Number of workers for the parallel process.
workers (
int, optional
): Number of workers for the parallel process.
test_mode (bool): Whether in the test_mode. Default: False.
test_mode (bool
, optional
): Whether in the test_mode. Default: False.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -402,8 +403,8 @@ class Waymo2KITTI(object):
...
@@ -402,8 +403,8 @@ class Waymo2KITTI(object):
camera projections corresponding with two returns.
camera projections corresponding with two returns.
range_image_top_pose (:obj:`Transform`): Range image pixel pose for
range_image_top_pose (:obj:`Transform`): Range image pixel pose for
top lidar.
top lidar.
ri_index (int): 0 for the first return,
1 for the second return.
ri_index (int
, optional
): 0 for the first return,
Default: 0.
1 for the second return.
Default: 0.
Returns:
Returns:
tuple[list[np.ndarray]]: (List of points with shape [N, 3],
tuple[list[np.ndarray]]: (List of points with shape [N, 3],
...
...
tools/deployment/mmdet3d2torchserve.py
0 → 100644
View file @
32a4328b
# Copyright (c) OpenMMLab. All rights reserved.
from
argparse
import
ArgumentParser
,
Namespace
from
pathlib
import
Path
from
tempfile
import
TemporaryDirectory
import
mmcv
try
:
from
model_archiver.model_packaging
import
package_model
from
model_archiver.model_packaging_utils
import
ModelExportUtils
except
ImportError
:
package_model
=
None
def
mmdet3d2torchserve
(
config_file
:
str
,
checkpoint_file
:
str
,
output_folder
:
str
,
model_name
:
str
,
model_version
:
str
=
'1.0'
,
force
:
bool
=
False
,
):
"""Converts MMDetection3D model (config + checkpoint) to TorchServe `.mar`.
Args:
config_file (str):
In MMDetection3D config format.
The contents vary for each task repository.
checkpoint_file (str):
In MMDetection3D checkpoint format.
The contents vary for each task repository.
output_folder (str):
Folder where `{model_name}.mar` will be created.
The file created will be in TorchServe archive format.
model_name (str):
If not None, used for naming the `{model_name}.mar` file
that will be created under `output_folder`.
If None, `{Path(checkpoint_file).stem}` will be used.
model_version (str, optional):
Model's version. Default: '1.0'.
force (bool, optional):
If True, if there is an existing `{model_name}.mar`
file under `output_folder` it will be overwritten.
Default: False.
"""
mmcv
.
mkdir_or_exist
(
output_folder
)
config
=
mmcv
.
Config
.
fromfile
(
config_file
)
with
TemporaryDirectory
()
as
tmpdir
:
config
.
dump
(
f
'
{
tmpdir
}
/config.py'
)
args
=
Namespace
(
**
{
'model_file'
:
f
'
{
tmpdir
}
/config.py'
,
'serialized_file'
:
checkpoint_file
,
'handler'
:
f
'
{
Path
(
__file__
).
parent
}
/mmdet3d_handler.py'
,
'model_name'
:
model_name
or
Path
(
checkpoint_file
).
stem
,
'version'
:
model_version
,
'export_path'
:
output_folder
,
'force'
:
force
,
'requirements_file'
:
None
,
'extra_files'
:
None
,
'runtime'
:
'python'
,
'archive_format'
:
'default'
})
manifest
=
ModelExportUtils
.
generate_manifest_json
(
args
)
package_model
(
args
,
manifest
)
def
parse_args
():
parser
=
ArgumentParser
(
description
=
'Convert MMDetection models to TorchServe `.mar` format.'
)
parser
.
add_argument
(
'config'
,
type
=
str
,
help
=
'config file path'
)
parser
.
add_argument
(
'checkpoint'
,
type
=
str
,
help
=
'checkpoint file path'
)
parser
.
add_argument
(
'--output-folder'
,
type
=
str
,
required
=
True
,
help
=
'Folder where `{model_name}.mar` will be created.'
)
parser
.
add_argument
(
'--model-name'
,
type
=
str
,
default
=
None
,
help
=
'If not None, used for naming the `{model_name}.mar`'
'file that will be created under `output_folder`.'
'If None, `{Path(checkpoint_file).stem}` will be used.'
)
parser
.
add_argument
(
'--model-version'
,
type
=
str
,
default
=
'1.0'
,
help
=
'Number used for versioning.'
)
parser
.
add_argument
(
'-f'
,
'--force'
,
action
=
'store_true'
,
help
=
'overwrite the existing `{model_name}.mar`'
)
args
=
parser
.
parse_args
()
return
args
if
__name__
==
'__main__'
:
args
=
parse_args
()
if
package_model
is
None
:
raise
ImportError
(
'`torch-model-archiver` is required.'
'Try: pip install torch-model-archiver'
)
mmdet3d2torchserve
(
args
.
config
,
args
.
checkpoint
,
args
.
output_folder
,
args
.
model_name
,
args
.
model_version
,
args
.
force
)
Prev
1
…
16
17
18
19
20
21
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment