Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
TS-MODELS-OPT
training
Autonomous-Driving-models
Commits
19472568
Commit
19472568
authored
Apr 08, 2026
by
雍大凯
Browse files
将子模块转换为普通目录
parent
51e55208
Changes
233
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
3828 additions
and
0 deletions
+3828
-0
docker-hub/MapTRv2/MapTR/tools/maptrv2/av2_vis_pred.py
docker-hub/MapTRv2/MapTR/tools/maptrv2/av2_vis_pred.py
+511
-0
docker-hub/MapTRv2/MapTR/tools/maptrv2/custom_av2_map_converter.py
...b/MapTRv2/MapTR/tools/maptrv2/custom_av2_map_converter.py
+797
-0
docker-hub/MapTRv2/MapTR/tools/maptrv2/custom_nusc_map_converter.py
.../MapTRv2/MapTR/tools/maptrv2/custom_nusc_map_converter.py
+944
-0
docker-hub/MapTRv2/MapTR/tools/maptrv2/nusc_vis_pred.py
docker-hub/MapTRv2/MapTR/tools/maptrv2/nusc_vis_pred.py
+391
-0
docker-hub/MapTRv2/MapTR/tools/misc/browse_dataset.py
docker-hub/MapTRv2/MapTR/tools/misc/browse_dataset.py
+240
-0
docker-hub/MapTRv2/MapTR/tools/misc/fuse_conv_bn.py
docker-hub/MapTRv2/MapTR/tools/misc/fuse_conv_bn.py
+67
-0
docker-hub/MapTRv2/MapTR/tools/misc/print_config.py
docker-hub/MapTRv2/MapTR/tools/misc/print_config.py
+26
-0
docker-hub/MapTRv2/MapTR/tools/misc/visualize_results.py
docker-hub/MapTRv2/MapTR/tools/misc/visualize_results.py
+49
-0
docker-hub/MapTRv2/MapTR/tools/model_converters/convert_votenet_checkpoints.py
...pTR/tools/model_converters/convert_votenet_checkpoints.py
+152
-0
docker-hub/MapTRv2/MapTR/tools/model_converters/publish_model.py
...hub/MapTRv2/MapTR/tools/model_converters/publish_model.py
+35
-0
docker-hub/MapTRv2/MapTR/tools/model_converters/regnet2mmdet.py
...-hub/MapTRv2/MapTR/tools/model_converters/regnet2mmdet.py
+89
-0
docker-hub/MapTRv2/MapTR/tools/test.py
docker-hub/MapTRv2/MapTR/tools/test.py
+262
-0
docker-hub/MapTRv2/MapTR/tools/train.py
docker-hub/MapTRv2/MapTR/tools/train.py
+265
-0
No files found.
docker-hub/MapTRv2/MapTR/tools/maptrv2/av2_vis_pred.py
0 → 100644
View file @
19472568
import
argparse
import
mmcv
import
os
import
shutil
import
torch
import
warnings
from
mmcv
import
Config
,
DictAction
from
mmcv.cnn
import
fuse_conv_bn
from
mmcv.parallel
import
MMDataParallel
,
MMDistributedDataParallel
from
mmcv.runner
import
(
get_dist_info
,
init_dist
,
load_checkpoint
,
wrap_fp16_model
)
from
mmdet3d.utils
import
collect_env
,
get_root_logger
from
mmdet3d.apis
import
single_gpu_test
from
mmdet3d.datasets
import
build_dataset
import
sys
sys
.
path
.
append
(
''
)
from
projects.mmdet3d_plugin.datasets.builder
import
build_dataloader
from
mmdet3d.models
import
build_model
from
mmdet.apis
import
set_random_seed
from
projects.mmdet3d_plugin.bevformer.apis.test
import
custom_multi_gpu_test
from
mmdet.datasets
import
replace_ImageToTensor
import
time
import
os.path
as
osp
import
numpy
as
np
from
PIL
import
Image
import
matplotlib.pyplot
as
plt
from
matplotlib
import
transforms
from
matplotlib.patches
import
Rectangle
from
shapely.geometry
import
LineString
import
cv2
import
copy
# Caption written onto each rendered camera view, keyed by the name of the
# ring-camera directory the image was loaded from.
caption_by_cam = {
    'ring_front_center': 'CAM_FRONT_CENTER',
    'ring_front_right': 'CAM_FRONT_RIGHT',
    'ring_front_left': 'CAM_FRONT_LEFT',
    'ring_rear_right': 'CAM_REAR_RIGHT',
    # Fixed typo: this caption used to read 'CAM_REAT_LEFT'.
    'ring_rear_left': 'CAM_REAR_LEFT',
    'ring_side_right': 'CAM_SIDE_RIGHT',
    'ring_side_left': 'CAM_SIDE_LEFT',
}
# Drawing colour for each map class, as (blue, green, red) tuples for OpenCV.
COLOR_MAPS_BGR = {
    # bgr colors
    'divider': (54, 137, 255),
    'boundary': (0, 0, 255),
    'ped_crossing': (255, 0, 0),
    'centerline': (0, 255, 0),
    'drivable_area': (171, 255, 255),
}
# Root directory that the camera image paths taken from ``img_metas`` are
# joined onto before the images are read. The original hard-coded checkout
# stays the default; set the MAPTR_DATA_ROOT environment variable to run the
# script from another location without editing the source.
data_path_prefix = os.environ.get(
    'MAPTR_DATA_ROOT',
    '/home/users/yunchi.zhang/project/MapTR',
)  # project root
def remove_nan_values(uv):
    """Drop every row of ``uv`` whose first or second column is NaN.

    Args:
        uv (np.ndarray): 2-D array; only columns 0 and 1 are inspected.

    Returns:
        np.ndarray: the rows of ``uv`` whose first two columns are both
        non-NaN, in their original order and with all columns kept.
    """
    # De Morgan: "u valid and v valid" is "not (u is NaN or v is NaN)".
    has_nan = np.logical_or(np.isnan(uv[:, 0]), np.isnan(uv[:, 1]))
    return uv[np.logical_not(has_nan)]
def interp_fixed_dist(line, sample_dist):
    """Sample points along a line at a fixed spacing, end points included.

    Args:
        line (LineString): line to sample; its ``length`` and
            ``interpolate`` members are used.
        sample_dist (float): distance between consecutive interior samples.

    Returns:
        np.ndarray: the sampled points, one row per sample.
    """
    total_length = line.length
    # Bracket the interior samples with both end points so that at least two
    # points come back even when sample_dist exceeds the line length.
    stations = [
        0,
        *np.arange(sample_dist, total_length, sample_dist),
        total_length,
    ]
    coords_per_station = [
        list(line.interpolate(station).coords) for station in stations
    ]
    return np.array(coords_per_station).squeeze()
def draw_visible_polyline_cv2(line, valid_pts_bool, image, color, thickness_px, map_class):
    """Render a polyline onto an image one segment at a time.

    Args:
        line: array of shape (K, 2) with the polyline coordinates; values
            are rounded to integers before drawing.
        valid_pts_bool: indexable of K flags telling which vertices may be
            rendered. A segment touching an invalid vertex is skipped.
        image: array of shape (H, W, 3) handed to the OpenCV drawing calls
            (a 3-channel BGR image according to the callers).
        color: tuple of shape (3,) with a BGR format color.
        thickness_px: stroke thickness in pixels.
        map_class: map class name; 'centerline' segments are drawn as
            arrows, every other class as anti-aliased line segments.
    """
    pixels = np.round(line).astype(int)  # type: ignore
    as_arrows = map_class == 'centerline'

    for idx in range(len(pixels) - 1):
        # Both end points of the segment must be flagged valid.
        if not (valid_pts_bool[idx] and valid_pts_bool[idx + 1]):
            continue

        start = (pixels[idx][0], pixels[idx][1])
        end = (pixels[idx + 1][0], pixels[idx + 1][1])

        if as_arrows:
            image = cv2.arrowedLine(image, start, end, color, thickness_px, 8, 0, 0.7)
        else:
            # Use anti-aliasing (AA) for curves
            image = cv2.line(
                image,
                pt1=start,
                pt2=end,
                color=color,
                thickness=thickness_px,
                lineType=cv2.LINE_AA,
            )
def points_ego2img(pts_ego, lidar2img):
    """Project ego-frame points into an image with a homogeneous matrix.

    Args:
        pts_ego (np.ndarray): points, one per row; a column of ones is
            appended before the projection.
        lidar2img (np.ndarray): matrix left-multiplied onto the homogeneous
            points.

    Returns:
        tuple: ``(uv, depth)``. ``depth`` is the third projected coordinate
        and ``uv`` holds the first two projected coordinates divided by it.
        Points rejected by ``remove_nan_values`` are missing from both.
    """
    ones_column = np.ones([len(pts_ego), 1])
    homogeneous = np.concatenate([pts_ego, ones_column], axis=-1)

    projected = (lidar2img @ homogeneous.T).T
    projected = remove_nan_values(projected)

    depth = projected[:, 2]
    uv = projected[:, :2] / projected[:, 2].reshape(-1, 1)
    return uv, depth
def draw_polyline_ego_on_img(polyline_ego, img_bgr, lidar2img, map_class, thickness):
    """Project an ego-frame polyline into a camera image and draw it.

    The polyline is resampled every 0.2 units, projected with ``lidar2img``
    and split into runs of consecutive samples that fall inside the image
    with positive depth. Every run of two or more samples is drawn on its
    own, so a line that leaves and re-enters the view is not bridged.

    Args:
        polyline_ego (np.ndarray): polyline vertices, one per row, with two
            or three columns; two-column input is lifted to z = 0.
        img_bgr (np.ndarray): image of shape (H, W, C) that is drawn onto.
        lidar2img (np.ndarray): projection matrix given to points_ego2img.
        map_class (str): key into COLOR_MAPS_BGR; also selects how
            draw_visible_polyline_cv2 renders the segments.
        thickness (int): stroke thickness in pixels.
    """
    # if 2-dimension, assume z=0
    if polyline_ego.shape[1] == 2:
        zeros = np.zeros((polyline_ego.shape[0], 1))
        polyline_ego = np.concatenate([polyline_ego, zeros], axis=1)

    polyline_ego = interp_fixed_dist(line=LineString(polyline_ego), sample_dist=0.2)
    uv, depth = points_ego2img(polyline_ego, lidar2img)

    h, w, _ = img_bgr.shape
    is_valid_x = np.logical_and(0 <= uv[:, 0], uv[:, 0] < w - 1)
    is_valid_y = np.logical_and(0 <= uv[:, 1], uv[:, 1] < h - 1)
    is_valid_z = depth > 0
    is_valid_points = np.logical_and.reduce([is_valid_x, is_valid_y, is_valid_z])

    if not is_valid_points.any():
        return

    def _draw_run(run):
        """Draw one run of consecutive visible samples; needs two points."""
        if len(run) < 2:
            return
        pixels = np.round(np.stack(run)).astype(np.int32)
        draw_visible_polyline_cv2(
            pixels,
            # Every point of a run is visible by construction; the mask is
            # sized to the run itself rather than to the whole polyline.
            valid_pts_bool=np.ones(len(pixels), dtype=bool),
            image=img_bgr,
            color=COLOR_MAPS_BGR[map_class],
            thickness_px=thickness,
            map_class=map_class,
        )

    run = []
    for point, valid in zip(uv, is_valid_points):
        if valid:
            run.append(point)
            continue
        # An invalid sample ends the current run: draw it and start over.
        _draw_run(run)
        run = []
    # Flush the run that reaches the end of the polyline.
    _draw_run(run)
def render_anno_on_pv(cam_img, anno, lidar2img):
    """Draw every annotated polyline onto one perspective-view camera image.

    Args:
        cam_img: image the polylines are drawn onto.
        anno (dict): maps a map-class name to an iterable of polylines.
        lidar2img: projection matrix forwarded unchanged for every polyline.
    """
    for map_class in anno:
        for polyline in anno[map_class]:
            draw_polyline_ego_on_img(
                polyline, cam_img, lidar2img, map_class, thickness=10)
def perspective(cam_coords, proj_mat):
    """Project column-wise points with ``proj_mat`` and normalise by depth.

    Args:
        cam_coords (np.ndarray): points stored one per column.
        proj_mat (np.ndarray): matrix left-multiplied onto ``cam_coords``.

    Returns:
        np.ndarray: one (u, v) row for every point whose third projected
        coordinate is strictly positive.
    """
    projected = proj_mat @ cam_coords
    in_front = projected[2, :] > 0
    visible = projected[:, in_front]
    # The small epsilon keeps the division finite for depths close to zero.
    normalised = visible[:2, :] / (visible[2, :] + 1e-7)
    return normalised.transpose(1, 0)
def parse_args():
    """Parse the command-line arguments of the visualisation script.

    Returns:
        argparse.Namespace: with the attributes ``config``, ``checkpoint``,
        ``score_thresh``, ``show_dir``, ``show_cam`` and ``gt_format``.
    """
    parser = argparse.ArgumentParser(description='vis hdmaptr map gt label')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument(
        '--score-thresh',
        default=0.4,
        type=float,
        # The help text used to read 'samples to visualize', which does not
        # describe a score threshold.
        help='minimum score a prediction needs to be visualized')
    parser.add_argument(
        '--show-dir', help='directory where visualizations will be saved')
    parser.add_argument(
        '--show-cam', action='store_true', help='show camera pic')
    parser.add_argument(
        '--gt-format',
        type=str,
        nargs='+',
        default=['fixed_num_pts', ],
        # The help text used to name "points" as the default, which is not
        # the value configured above.
        help='vis format, default is "fixed_num_pts", '
        'support ["se_pts","bbox","fixed_num_pts","polyline_pts"]')
    return parser.parse_args()
# Map class name for each label index used by the GT and the predictions.
_VIS_MAP_CLASSES = ('divider', 'ped_crossing', 'boundary', 'centerline')
# matplotlib colour per label index:
# divider->orange, ped->blue, boundary->red, centerline->green
_VIS_BEV_COLORS = ('orange', 'blue', 'red', 'green')
# Label index whose polylines are simplified and drawn as arrows.
_VIS_CENTERLINE_LABEL = 3


def _vis_import_plugin(dir_path):
    """Import the package named by the directory part of ``dir_path``."""
    import importlib

    module_path = '.'.join(os.path.dirname(dir_path).split('/'))
    print(module_path)
    importlib.import_module(module_path)


def _vis_open_bev_figure():
    """Open a blank 4x2 figure covering x in [-30, 30] and y in [-15, 15]."""
    plt.figure(figsize=(4, 2))
    plt.xlim(-30, 30)
    plt.ylim(-15, 15)
    plt.axis('off')


def _vis_save_bev_figure(car_img, out_path):
    """Overlay the car icon at the origin, save the figure and close it."""
    plt.imshow(car_img, extent=[-1.5, 1.5, -1.2, 1.2])
    plt.savefig(out_path, bbox_inches='tight', format='png', dpi=1200)
    plt.close()


def _vis_simplify_centerline(pts):
    """Return the vertices of ``pts`` after a 0.2-tolerance simplification."""
    instance = LineString(pts).simplify(0.2, preserve_topology=True)
    return np.array(list(instance.coords))


def _vis_plot_centerline(pts):
    """Draw a centerline as green arrows between consecutive vertices."""
    x = pts[:, 0]
    y = pts[:, 1]
    plt.quiver(
        x[:-1], y[:-1], x[1:] - x[:-1], y[1:] - y[:-1],
        scale_units='xy', angles='xy', scale=1, color='green',
        headwidth=5, headlength=6, width=0.006, alpha=0.8, zorder=-1)


def _vis_put_caption(cam_img, caption):
    """Write ``caption`` in white on a filled black box at the top-left."""
    lw = 8
    tf = max(lw - 1, 1)
    w, h = cv2.getTextSize(caption, 0, fontScale=lw / 3, thickness=tf)[0]  # text width, height
    cv2.rectangle(cam_img, (0, 0), (w, h + 3), (0, 0, 0), -1, cv2.LINE_AA)  # filled
    cv2.putText(
        cam_img, caption, (0, h + 2), 0, lw / 3, (255, 255, 255),
        thickness=tf, lineType=cv2.LINE_AA)


def _vis_compose_surround_view(cams):
    """Tile the seven rendered camera images into one white-backed canvas.

    The slice sizes below require the front-center image to be 2048 rows by
    1550 columns and every other image 1550 rows by 2048 columns.
    """
    white = (255, 255, 255)

    # Top row: front-left | front-center | front-right, bottom aligned.
    first_row = np.full((2048, 1550 + 2048 * 2, 3), white, dtype=np.uint8)
    first_row[(2048 - 1550):, :2048, :] = cams['ring_front_left']
    first_row[:, 2048:(2048 + 1550), :] = cams['ring_front_center']
    first_row[(2048 - 1550):, 3598:, :] = cams['ring_front_right']

    # Bottom row: side-left | rear-left | rear-right | side-right.
    second_row = np.full((1550, 2048 * 4, 3), white, dtype=np.uint8)
    second_row[:, :2048, :] = cams['ring_side_left']
    second_row[:, 2048:4096, :] = cams['ring_rear_left']
    second_row[:, 4096:6144, :] = cams['ring_rear_right']
    second_row[:, 6144:, :] = cams['ring_side_right']

    # Stretch the top row to the width of the bottom row before stacking.
    resized_first_row = cv2.resize(first_row, (8192, 2972))
    full_canvas = np.full((2972 + 1550, 8192, 3), white, dtype=np.uint8)
    full_canvas[:2972, :, :] = resized_first_row
    full_canvas[2972:, :, :] = second_row
    return full_canvas


def main():
    """Run a checkpoint over the test split and save GT / prediction images.

    For every sample that has ground truth, three files are written into
    ``<show_dir>/<sample name>/``: ``GT_MAP.png``, ``PRED_MAP.png`` and
    ``surroud_view.jpg``. Predictions are kept when their score exceeds
    ``--score-thresh``. A ``final_dict.pkl`` holding the image paths, GT map
    and raw predictions of every processed sample is dumped into
    ``show_dir`` at the end.
    """
    args = parse_args()

    cfg = Config.fromfile(args.config)

    # import modules from plugin/xx, registry will be updated
    if hasattr(cfg, 'plugin') and cfg.plugin:
        if hasattr(cfg, 'plugin_dir'):
            _vis_import_plugin(cfg.plugin_dir)
        else:
            # import dir is the dirpath for the config file
            _vis_import_plugin(args.config)

    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    cfg.model.pretrained = None
    # in case the test dataset is concatenated
    samples_per_gpu = 1
    if isinstance(cfg.data.test, dict):
        cfg.data.test.test_mode = True
        samples_per_gpu = cfg.data.test.pop('samples_per_gpu', 1)
        if samples_per_gpu > 1:
            # Replace 'ImageToTensor' to 'DefaultFormatBundle'
            cfg.data.test.pipeline = replace_ImageToTensor(
                cfg.data.test.pipeline)
    elif isinstance(cfg.data.test, list):
        for ds_cfg in cfg.data.test:
            ds_cfg.test_mode = True
        samples_per_gpu = max(
            [ds_cfg.pop('samples_per_gpu', 1) for ds_cfg in cfg.data.test])
        if samples_per_gpu > 1:
            for ds_cfg in cfg.data.test:
                ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline)

    if args.show_dir is None:
        args.show_dir = osp.join(
            './work_dirs',
            osp.splitext(osp.basename(args.config))[0],
            'vis_pred')
    # create vis_label dir
    mmcv.mkdir_or_exist(osp.abspath(args.show_dir))
    cfg.dump(osp.join(args.show_dir, osp.basename(args.config)))
    logger = get_root_logger()
    logger.info(f'DONE create vis_pred dir: {args.show_dir}')

    dataset = build_dataset(cfg.data.test)
    dataset.is_vis_on_test = True  # TODO, this is a hack
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=samples_per_gpu,
        # workers_per_gpu=cfg.data.workers_per_gpu,
        workers_per_gpu=0,
        dist=False,
        shuffle=False,
        nonshuffler_sampler=cfg.data.nonshuffler_sampler,
    )
    logger.info('Done build test data set')

    # build the model and load checkpoint
    cfg.model.train_cfg = None
    model = build_model(cfg.model, test_cfg=cfg.get('test_cfg'))
    if cfg.get('fp16', None) is not None:
        wrap_fp16_model(model)
    logger.info('loading check point')
    checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
    checkpoint_meta = checkpoint.get('meta', {})
    if 'CLASSES' in checkpoint_meta:
        model.CLASSES = checkpoint['meta']['CLASSES']
    else:
        model.CLASSES = dataset.CLASSES
    # palette for visualization in segmentation tasks
    if 'PALETTE' in checkpoint_meta:
        model.PALETTE = checkpoint['meta']['PALETTE']
    elif hasattr(dataset, 'PALETTE'):
        # segmentation dataset has `PALETTE` attribute
        model.PALETTE = dataset.PALETTE
    logger.info('DONE load check point')
    model = MMDataParallel(model, device_ids=[0])
    model.eval()

    # get car icon
    car_img = Image.open('./figs/car.png')

    logger.info('BEGIN vis test dataset samples gt label & pred')

    dataset = data_loader.dataset
    prog_bar = mmcv.ProgressBar(len(dataset))
    final_dict = {}
    for i, data in enumerate(data_loader):
        # Skip the sample when none of its GT labels differs from -1.
        # `not` replaces the former bitwise `~`, which is only a logical
        # negation for tensor/array booleans, never for a plain bool.
        if not (data['gt_labels_3d'].data[0][0] != -1).any():
            logger.error(f'\n empty gt for index {i}, continue')
            continue

        img_metas = data['img_metas'][0].data[0]
        gt_bboxes_3d = data['gt_bboxes_3d'].data[0]
        gt_labels_3d = data['gt_labels_3d'].data[0]

        # Sample name: base name of the point-cloud file up to its first dot.
        pts_filename = osp.basename(img_metas[0]['pts_filename']).split('.')[0]

        sample_dict = {}
        with torch.no_grad():
            result = model(return_loss=False, rescale=True, **data)
        sample_dir = osp.join(args.show_dir, pts_filename)
        mmcv.mkdir_or_exist(osp.abspath(sample_dir))

        # Record each camera image path together with a projection onto the
        # un-augmented image (the augmentation matrix is inverted and applied
        # on top of lidar2img).
        img_path_dict = {}
        for filepath, lidar2img, img_aug in zip(
                img_metas[0]['filename'],
                img_metas[0]['lidar2img'],
                img_metas[0]['img_aug_matrix']):
            lidar2orimg = np.dot(np.linalg.inv(img_aug), lidar2img)
            cam_name = os.path.dirname(filepath).split('/')[-1]
            img_path_dict[cam_name] = dict(
                filepath=filepath, lidar2img=lidar2orimg)
        sample_dict['imgs_path'] = img_path_dict

        # Group the GT polylines by class name.
        gt_dict = {name: [] for name in _VIS_MAP_CLASSES}
        gt_lines_instance = gt_bboxes_3d[0].instance_list
        for gt_line_instance, gt_label_3d in zip(
                gt_lines_instance, gt_labels_3d[0]):
            label = int(gt_label_3d)
            if not 0 <= label < len(_VIS_MAP_CLASSES):
                raise NotImplementedError
            gt_dict[_VIS_MAP_CLASSES[label]].append(
                np.array(list(gt_line_instance.coords)))
        sample_dict['gt_map'] = gt_dict

        result_dict = result[0]['pts_bbox']
        sample_dict['pred_map'] = result_dict

        # visualize gt
        _vis_open_bev_figure()
        for name, color in (('divider', 'orange'),
                            ('ped_crossing', 'blue'),
                            ('boundary', 'red')):
            for pts in gt_dict[name]:
                plt.plot(pts[:, 0], pts[:, 1], color=color,
                         linewidth=1, alpha=0.8, zorder=-1)
        for pts in gt_dict['centerline']:
            _vis_plot_centerline(_vis_simplify_centerline(pts))
        _vis_save_bev_figure(car_img, osp.join(sample_dir, 'GT_MAP.png'))

        # visualize pred
        scores_3d = result_dict['scores_3d']
        labels_3d = result_dict['labels_3d']
        pts_3d = result_dict['pts_3d']
        # Honour --score-thresh; a hard-coded 0.3 used to be applied here,
        # which silently ignored the command-line option.
        keep = scores_3d > args.score_thresh

        _vis_open_bev_figure()
        pred_anno = {name: [] for name in _VIS_MAP_CLASSES}
        for pred_label_3d, pred_pts_3d in zip(labels_3d[keep], pts_3d[keep]):
            label = int(pred_label_3d)
            pts = pred_pts_3d.numpy()
            if label == _VIS_CENTERLINE_LABEL:
                pts = _vis_simplify_centerline(pts)
                pred_anno['centerline'].append(pts)
                _vis_plot_centerline(pts)
            else:
                pred_anno[_VIS_MAP_CLASSES[label]].append(pts)
                plt.plot(pts[:, 0], pts[:, 1], color=_VIS_BEV_COLORS[label],
                         linewidth=1, alpha=0.8, zorder=-1)
        _vis_save_bev_figure(car_img, osp.join(sample_dir, 'PRED_MAP.png'))

        # Draw the predictions onto every camera image and caption it.
        rendered_cams_dict = {}
        for key, cam_dict in img_path_dict.items():
            cam_img = cv2.imread(
                osp.join(data_path_prefix, cam_dict['filepath']))
            render_anno_on_pv(cam_img, pred_anno, cam_dict['lidar2img'])
            if 'front' not in key:
                # Non-front cameras are mirrored horizontally.
                cam_img = cv2.flip(cam_img, 1)
            _vis_put_caption(cam_img, caption_by_cam[key])
            rendered_cams_dict[key] = cam_img

        full_canvas = _vis_compose_surround_view(rendered_cams_dict)
        cams_img_path = osp.join(sample_dir, 'surroud_view.jpg')
        cv2.imwrite(cams_img_path, full_canvas,
                    [cv2.IMWRITE_JPEG_QUALITY, 70])

        final_dict[pts_filename] = sample_dict
        prog_bar.update()

    mmcv.dump(final_dict, osp.join(args.show_dir, 'final_dict.pkl'))
    logger.info('\n DONE vis test dataset samples gt label & pred')


if __name__ == '__main__':
    main()
docker-hub/MapTRv2/MapTR/tools/maptrv2/custom_av2_map_converter.py
0 → 100644
View file @
19472568
from
functools
import
partial
from
multiprocessing
import
Pool
import
multiprocessing
from
random
import
sample
import
time
import
mmcv
import
logging
from
pathlib
import
Path
from
os
import
path
as
osp
import
os
from
av2.datasets.sensor.av2_sensor_dataloader
import
AV2SensorDataLoader
from
av2.map.lane_segment
import
LaneMarkType
,
LaneSegment
from
av2.map.map_api
import
ArgoverseStaticMap
from
tqdm
import
tqdm
import
argparse
import
networkx
as
nx
from
av2.map.map_primitives
import
Polyline
from
nuscenes.map_expansion.map_api
import
NuScenesMapExplorer
from
shapely
import
affinity
,
ops
from
shapely.geometry
import
Polygon
,
LineString
,
box
,
MultiPolygon
,
MultiLineString
from
shapely.strtree
import
STRtree
from
nuscenes.eval.common.utils
import
quaternion_yaw
,
Quaternion
from
av2.geometry.se3
import
SE3
import
numpy
as
np
import
math
from
shapely.geometry
import
CAP_STYLE
,
JOIN_STYLE
from
scipy.spatial
import
distance
import
warnings
warnings
.
filterwarnings
(
"ignore"
)
# Cameras queried for every sample, in the order their image paths are
# collected. The two stereo cameras are deliberately left out.
CAM_NAMES = [
    'ring_front_center',
    'ring_front_right',
    'ring_front_left',
    'ring_rear_right',
    'ring_rear_left',
    'ring_side_right',
    'ring_side_left',
    # 'stereo_front_left', 'stereo_front_right',
]
# Log ids that are removed from the split before conversion.
# The "official" ids are the failing logs listed in the av2 map README:
# https://github.com/argoverse/av2-api/blob/05b7b661b7373adb5115cf13378d344d2ee43906/src/av2/map/README.md#training-online-map-inference-models
FAIL_LOGS = [
    # official
    '75e8adad-50a6-3245-8726-5e612db3d165',
    '54bc6dbc-ebfb-3fba-b5b3-57f88b4b79ca',
    'af170aac-8465-3d7b-82c5-64147e94af7d',
    '6e106cf8-f6dd-38f6-89c8-9be7a71e7275',
    # observed
    '01bb304d-7bd8-35f8-bbef-7086b688e35e',
    '453e5558-6363-38e3-bf9b-42b5ba0a6f1d',
]
def parse_args():
    """Build and evaluate the converter's command-line interface.

    Returns:
        argparse.Namespace: holds ``data_root``, ``pc_range`` and ``nproc``.
    """
    cli = argparse.ArgumentParser(description='Data converter arg parser')
    cli.add_argument(
        '--data-root',
        type=str,
        help='specify the root path of dataset')
    cli.add_argument(
        '--pc-range',
        type=float,
        nargs='+',
        default=[-30.0, -15.0, -5.0, 30.0, 15.0, 3.0],
        help='specify the perception point cloud range')
    cli.add_argument(
        '--nproc',
        type=int,
        default=64,
        required=False,
        help='workers to process data')
    return cli.parse_args()
def create_av2_infos_mp(root_path,
                        info_prefix,
                        dest_path=None,
                        split='train',
                        num_multithread=64,
                        pc_range=[-30.0, -15.0, -5.0, 30.0, 15.0, 3.0]):
    """Create info file of av2 dataset.

    Given the raw data, generate its related info file in pkl format.

    Args:
        root_path (str): Path of the data root.
        info_prefix (str): Prefix of the info file to be generated.
        dest_path (str): Path to store generated file, default to the
            split directory ``<root_path>/<split>``.
        split (str): Split of the data.
            Default: 'train'
        num_multithread (int): Number of worker processes in the pool.
            Default: 64
        pc_range (list[float]): Range forwarded to get_data_from_logid for
            every log. It is only read here, never modified.
    """
    root_path = osp.join(root_path, split)
    if dest_path is None:
        dest_path = root_path

    loader = AV2SensorDataLoader(Path(root_path), Path(root_path))
    # Drop the logs listed in FAIL_LOGS, keeping the loader's order.
    log_ids = [
        log_id for log_id in loader.get_log_ids() if log_id not in FAIL_LOGS
    ]

    print('collecting samples...')
    start_time = time.time()
    print('num cpu:', multiprocessing.cpu_count())
    print(f'using {num_multithread} threads')

    # to suppress logging from av2.utils.synchronization_database
    sdb_logger = logging.getLogger('av2.utils.synchronization_database')
    prev_level = sdb_logger.level
    sdb_logger.setLevel(logging.CRITICAL)
    try:
        fn = partial(
            get_data_from_logid,
            loader=loader,
            data_root=root_path,
            pc_range=pc_range)
        # map() returns the per-log results in the order of log_ids.
        with Pool(num_multithread) as pool:
            results = pool.map(fn, log_ids)
    finally:
        # Restore the logger even when a worker raised; previously a failure
        # left the logger silenced for the rest of the process.
        sdb_logger.setLevel(prev_level)

    samples = []
    discarded = 0
    for log_samples, log_discarded in results:
        for sample in log_samples:
            # Running index over all kept samples across all logs.
            sample['sample_idx'] = len(samples)
            samples.append(sample)
        discarded += log_discarded

    print(f'{len(samples)} available samples, {discarded} samples discarded')
    print('collected in {}s'.format(time.time() - start_time))

    infos = dict(samples=samples)
    info_path = osp.join(
        dest_path, '{}_map_infos_{}.pkl'.format(info_prefix, split))
    print(f'saving results to {info_path}')
    mmcv.dump(infos, info_path)
def get_divider(avm):
    """Collect the marked left/right boundaries of all scenario lane segments.

    Args:
        avm: map object providing ``get_scenario_lane_segments()``.

    Returns:
        list: the ``xyz`` of every left or right lane boundary whose mark
        type is not ``LaneMarkType.NONE``, left before right per segment.
    """
    dividers = []
    for segment in avm.get_scenario_lane_segments():
        sides = (
            (segment.left_mark_type, segment.left_lane_boundary),
            (segment.right_mark_type, segment.right_lane_boundary),
        )
        dividers.extend(
            boundary.xyz
            for mark_type, boundary in sides
            if mark_type not in (LaneMarkType.NONE, ))
    return dividers
def get_boundary(avm):
    """Return the ``xyz`` of every scenario vector drivable area.

    Args:
        avm: map object providing ``get_scenario_vector_drivable_areas()``.

    Returns:
        list: one ``xyz`` entry per drivable area, in the map's order.
    """
    return [area.xyz for area in avm.get_scenario_vector_drivable_areas()]
def get_ped(avm):
    """Return the ``polygon`` of every scenario pedestrian crossing.

    Args:
        avm: map object providing ``get_scenario_ped_crossings()``.

    Returns:
        list: one ``polygon`` entry per crossing, in the map's order.
    """
    return [crossing.polygon for crossing in avm.get_scenario_ped_crossings()]
def get_data_from_logid(log_id,
                        loader: AV2SensorDataLoader,
                        data_root,
                        pc_range=[-30.0, -15.0, -5.0, 30.0, 15.0, 3.0]):
    """Build the sample infos of a single log.

    One sample is produced per lidar timestamp of the log. A timestamp is
    discarded when any camera in CAM_NAMES or the lidar has no closest file.

    Args:
        log_id (str): id of the log, also its directory name in data_root.
        loader (AV2SensorDataLoader): loader used for all sensor queries.
        data_root (str): directory containing one sub-directory per log.
        pc_range (list[float]): range forwarded to extract_local_map.

    Returns:
        tuple: ``(samples, discarded)`` - the list of sample dicts and the
        number of timestamps that were dropped.

    Raises:
        RuntimeError: if the log's ``map`` directory does not contain
            exactly one ``log_map_archive_*.json`` file.
    """
    map_dir = Path(osp.join(data_root, log_id, "map"))
    map_files = sorted(map_dir.glob("log_map_archive_*.json"))
    if len(map_files) != 1:
        raise RuntimeError(f"JSON file containing vector map data is missing (searched in {map_dir})")
    avm = ArgoverseStaticMap.from_json(map_files[0])

    samples = []
    discarded = 0

    # We use lidar timestamps to query all sensors.
    # The frequency is 10Hz
    for ts in loader._sdb.per_log_lidar_timestamps_index[log_id]:
        img_fpaths = [
            loader.get_closest_img_fpath(log_id, cam_name, ts)
            for cam_name in CAM_NAMES
        ]
        lidar_fpath = loader.get_closest_lidar_fpath(log_id, ts)

        # If bad sensor synchronization, discard the sample
        if None in img_fpaths or lidar_fpath is None:
            discarded += 1
            continue

        cams = {}
        for cam_name, img_fpath in zip(CAM_NAMES, img_fpaths):
            pinhole_cam = loader.get_log_pinhole_camera(log_id, cam_name)
            # The image file stem is parsed as the camera timestamp.
            cam_pose = loader.get_city_SE3_ego(log_id, int(img_fpath.stem))
            cams[cam_name] = dict(
                img_fpath=str(img_fpath),
                intrinsics=pinhole_cam.intrinsics.K,
                extrinsics=pinhole_cam.extrinsics,
                e2g_translation=cam_pose.translation,
                e2g_rotation=cam_pose.rotation,
            )

        ego_pose = loader.get_city_SE3_ego(log_id, int(ts))
        e2g_translation = ego_pose.translation
        e2g_rotation = ego_pose.rotation

        info = dict(
            e2g_translation=e2g_translation,
            e2g_rotation=e2g_rotation,
            cams=cams,
            lidar_path=str(lidar_fpath),
            # map_fpath=map_fname,
            timestamp=str(ts),
            log_id=log_id,
            token=str(log_id + '_' + str(ts)))
        info["annotation"] = extract_local_map(
            avm, e2g_translation, e2g_rotation, pc_range)
        samples.append(info)

    return samples, discarded
def extract_local_map(avm, e2g_translation, e2g_rotation, pc_range):
    """Extract the map elements around one ego pose.

    Args:
        avm: static map of the log.
        e2g_translation: ego translation; its first two entries centre the
            patch.
        e2g_rotation: ego rotation matrix, converted to a quaternion to get
            the patch yaw.
        pc_range: sequence indexed at 0, 1, 3 and 4 to size the patch.

    Returns:
        dict: keys ``divider``, ``ped_crossing``, ``boundary`` and
        ``centerline``, each holding the result of the matching
        ``extract_local_*`` helper.
    """
    # Patch extent: entries 1/4 give the first size, entries 0/3 the second.
    patch_size = (pc_range[4] - pc_range[1], pc_range[3] - pc_range[0])

    map_pose = e2g_translation[:2]
    patch_box = (map_pose[0], map_pose[1], patch_size[0], patch_size[1])

    rotation = Quaternion._from_matrix(e2g_rotation)
    # Converted with 180 / pi, i.e. presumably radians to degrees.
    patch_angle = quaternion_yaw(rotation) / np.pi * 180

    ego_SE3_city = SE3(e2g_rotation, e2g_translation).inverse()

    nearby_centerlines = generate_nearby_centerlines(avm, patch_box, patch_angle)
    nearby_dividers = generate_nearby_dividers(avm, patch_box, patch_angle)

    ped_crossing = extract_local_ped_crossing(
        avm, ego_SE3_city, patch_box, patch_angle, patch_size)
    boundary = extract_local_boundary(
        avm, ego_SE3_city, patch_box, patch_angle, patch_size)
    centerline = extract_local_centerline(
        nearby_centerlines, ego_SE3_city, patch_box, patch_angle, patch_size)
    divider = extract_local_divider(
        nearby_dividers, ego_SE3_city, patch_box, patch_angle, patch_size)

    return dict(
        divider=divider,
        ped_crossing=ped_crossing,
        boundary=boundary,
        centerline=centerline,
    )
def generate_nearby_centerlines(avm, patch_box, patch_angle):
    """Merge the centerlines of lane segments near the patch into paths.

    Lane segments whose polygon is valid and intersects the patch are kept.
    Their centerline points (rounded to 3 decimals) become nodes of a
    directed graph: consecutive points of a centerline are linked, and the
    end of every kept predecessor / start of every kept successor is linked
    to the segment. Every simple path from a node without incoming edges to
    a node without outgoing edges is returned as one simplified line.

    Args:
        avm: static map providing the lane segments and their centerlines.
        patch_box: patch description handed to
            ``NuScenesMapExplorer.get_patch_coord``.
        patch_angle: patch rotation handed to the same call.

    Returns:
        list[LineString]: merged centerline paths, each simplified with a
        tolerance of 0.2.
    """
    patch = NuScenesMapExplorer.get_patch_coord(patch_box, patch_angle)

    # Keep the segments whose footprint is valid and touches the patch.
    segments_by_id = {
        segment.id: segment for segment in avm.get_scenario_lane_segments()
    }
    nearby = {}
    for seg_id, segment in segments_by_id.items():
        footprint = Polygon(segment.polygon_boundary)
        if footprint.is_valid and not footprint.intersection(patch).is_empty:
            nearby[seg_id] = segment

    # Rounded centerline points of every kept segment.
    centerline_pts = {}
    for seg_id in nearby:
        polyline = Polyline.from_array(
            avm.get_lane_segment_centerline(seg_id).round(3))
        centerline_pts[seg_id] = np.array(
            LineString(polyline.xyz).coords).round(3)

    # Build the point graph. The order of add_edge calls is kept as-is
    # because it determines the order in which paths are enumerated.
    graph = nx.DiGraph()
    for seg_id, segment in nearby.items():
        pts = centerline_pts[seg_id]
        head, tail = pts[0], pts[-1]
        for src, dst in zip(pts[:-1], pts[1:]):
            graph.add_edge(tuple(src), tuple(dst))
        for pred_id in segment.predecessors:
            if pred_id in centerline_pts:
                graph.add_edge(tuple(centerline_pts[pred_id][-1]), tuple(head))
        for succ_id in segment.successors:
            if succ_id in centerline_pts:
                graph.add_edge(tuple(tail), tuple(centerline_pts[succ_id][0]))

    roots = [node for node, degree in graph.in_degree() if degree == 0]
    leaves = [node for node, degree in graph.out_degree() if degree == 0]

    all_paths = []
    for root in roots:
        all_paths.extend(nx.all_simple_paths(graph, root, leaves))

    return [
        LineString(path).simplify(0.2, preserve_topology=True)
        for path in all_paths
    ]
def generate_nearby_dividers(avm, patch_box, patch_angle):
    """Build divider polylines from lane boundaries of lane segments near a patch.

    Lane segments returned by ``avm.get_scenario_lane_segments()`` are kept
    when the polygon built from their ``polygon_boundary`` is valid and
    intersects the patch from ``NuScenesMapExplorer.get_patch_coord``.
    Segments flagged ``is_intersection`` are dropped.  The left / right lane
    boundaries of the remaining segments (only where the matching neighbour id
    is not ``None``) are chained into paths with ``get_path``.

    Args:
        avm: Map object exposing ``get_scenario_lane_segments()``.
        patch_box: Patch box forwarded to ``get_patch_coord``.
        patch_angle: Patch angle forwarded to ``get_patch_coord``.

    Returns:
        list: Simplified shapely ``LineString`` paths, left-boundary paths
        first, then right-boundary paths.
    """

    def get_path(ls_dict):
        """Chain the ``'centerline'`` polylines of ``ls_dict`` into root-to-leaf paths."""
        graph = nx.DiGraph()
        # Collected like the original code did; nothing visible here reads it.
        junction_pts_list = []
        for segment in ls_dict.values():
            seg_pts = np.array(
                LineString(segment['centerline'].xyz).coords).round(3)
            first_pt, last_pt = seg_pts[0], seg_pts[-1]

            # Consecutive vertices of the segment itself.
            for src, dst in zip(seg_pts[:-1], seg_pts[1:]):
                graph.add_edge(tuple(src), tuple(dst))

            # Link the tail of every predecessor present in ls_dict to our head.
            incoming = 0
            for pred_id in segment['predecessors']:
                if pred_id not in ls_dict:
                    continue
                incoming += 1
                pred_pts = np.array(
                    LineString(ls_dict[pred_id]['centerline'].xyz).coords
                ).round(3)
                graph.add_edge(tuple(pred_pts[-1]), tuple(first_pt))
            if incoming > 1:
                junction_pts_list.append(tuple(first_pt))

            # Link our tail to the head of every successor present in ls_dict.
            outgoing = 0
            for succ_id in segment['successors']:
                if succ_id not in ls_dict:
                    continue
                outgoing += 1
                succ_pts = np.array(
                    LineString(ls_dict[succ_id]['centerline'].xyz).coords
                ).round(3)
                graph.add_edge(tuple(last_pt), tuple(succ_pts[0]))
            if outgoing > 1:
                junction_pts_list.append(tuple(last_pt))

        leaves = [node for node, deg in graph.out_degree() if deg == 0]
        all_paths = []
        for node, deg in graph.in_degree():
            if deg == 0:
                all_paths.extend(nx.all_simple_paths(graph, node, leaves))

        return [
            LineString(path).simplify(0.2, preserve_topology=True)
            for path in all_paths
        ]

    patch = NuScenesMapExplorer.get_patch_coord(patch_box, patch_angle)

    scene_ls_dict = {}
    for lane_seg in avm.get_scenario_lane_segments():
        scene_ls_dict[lane_seg.id] = dict(
            ls=lane_seg,
            polygon=Polygon(lane_seg.polygon_boundary),
            predecessors=lane_seg.predecessors,
            successors=lane_seg.successors,
        )

    # Lane segments whose footprint is valid and touches the patch.
    nearby_ls_dict = {}
    for seg_id, entry in scene_ls_dict.items():
        footprint = entry['polygon']
        if not footprint.is_valid:
            continue
        if footprint.intersection(patch).is_empty:
            continue
        nearby_ls_dict[seg_id] = entry['ls']

    divider_ls_dict = {
        seg_id: lane_seg
        for seg_id, lane_seg in nearby_ls_dict.items()
        if not lane_seg.is_intersection
    }

    left_lane_dict = {}
    right_lane_dict = {}
    for seg_id, lane_seg in divider_ls_dict.items():
        if lane_seg.left_neighbor_id is not None:
            left_lane_dict[seg_id] = dict(
                polyline=lane_seg.left_lane_boundary,
                predecessors=lane_seg.predecessors,
                successors=lane_seg.successors,
                left_neighbor_id=lane_seg.left_neighbor_id,
            )
        if lane_seg.right_neighbor_id is not None:
            right_lane_dict[seg_id] = dict(
                polyline=lane_seg.right_lane_boundary,
                predecessors=lane_seg.predecessors,
                successors=lane_seg.successors,
                right_neighbor_id=lane_seg.right_neighbor_id,
            )

    # Drop the right-boundary entry of a segment that is some kept segment's
    # left neighbour, then the mirrored case on what is left.
    for entry in left_lane_dict.values():
        right_lane_dict.pop(entry['left_neighbor_id'], None)
    for entry in right_lane_dict.values():
        left_lane_dict.pop(entry['right_neighbor_id'], None)

    # get_path reads the geometry from the 'centerline' key.
    for entry in left_lane_dict.values():
        entry['centerline'] = entry['polyline']
    for entry in right_lane_dict.values():
        entry['centerline'] = entry['polyline']

    left_paths = get_path(left_lane_dict)
    right_paths = get_path(right_lane_dict)
    return left_paths + right_paths
def proc_polygon(polygon, ego_SE3_city):
    """Rebuild ``polygon`` from coordinates passed through ``ego_SE3_city``.

    The exterior ring and every interior ring are converted to arrays,
    transformed with ``ego_SE3_city.transform_point_cloud`` and used to build
    a new shapely ``Polygon``.

    Args:
        polygon: Shapely polygon (presumably in city frame, going by the
            original variable names).
        ego_SE3_city: Object exposing ``transform_point_cloud(points)``.

    Returns:
        Polygon: New polygon built from the transformed rings.
    """
    def _transform(ring):
        ring_pts = np.array(list(ring.coords))
        return ego_SE3_city.transform_point_cloud(ring_pts)

    shell = _transform(polygon.exterior)
    holes = [_transform(ring)[:, :3] for ring in polygon.interiors]
    return Polygon(shell[:, :3], holes)
def proc_line(line, ego_SE3_city):
    """Rebuild ``line`` from coordinates passed through ``ego_SE3_city``.

    Args:
        line: Shapely line geometry exposing ``coords``.
        ego_SE3_city: Object exposing ``transform_point_cloud(points)``.

    Returns:
        LineString: Line built from the first three columns of the
        transformed points.
    """
    source_pts = np.array(list(line.coords))
    moved_pts = ego_SE3_city.transform_point_cloud(source_pts)
    return LineString(moved_pts[:, :3])
def extract_local_centerline(nearby_centerlines, ego_SE3_city, patch_box,
                             patch_angle, patch_size):
    """Clip centerlines to the patch, transform them and drop near-duplicates.

    Each non-empty input line is intersected with the patch; every resulting
    piece is passed through ``proc_line``.  Lines are then buffered by 1 and
    a line is discarded when its buffer has IoU >= 0.90 with the buffer of a
    line that was kept earlier.

    Args:
        nearby_centerlines: Iterable of shapely line geometries.
        ego_SE3_city: Transform forwarded to ``proc_line``.
        patch_box: Patch box forwarded to ``get_patch_coord``.
        patch_angle: Patch angle forwarded to ``get_patch_coord``.
        patch_size: Accepted for signature compatibility; not used here.

    Returns:
        list[np.ndarray]: Coordinate arrays of the kept lines.
    """
    patch = NuScenesMapExplorer.get_patch_coord(patch_box, patch_angle)

    centerlines = []
    for line in nearby_centerlines:
        if line.is_empty:
            # Skip lines without nodes.
            continue
        clipped = line.intersection(patch)
        if clipped.is_empty:
            continue
        if clipped.geom_type == 'MultiLineString':
            for piece in clipped.geoms:
                if not piece.is_empty:
                    centerlines.append(proc_line(piece, ego_SE3_city))
        else:
            centerlines.append(proc_line(clipped, ego_SE3_city))

    buffered = [
        line.buffer(1, cap_style=CAP_STYLE.flat, join_style=JOIN_STYLE.mitre)
        for line in centerlines
    ]
    # tree.query returns geometry objects here, so map them back by identity.
    index_by_id = {id(poly): pos for pos, poly in enumerate(buffered)}
    tree = STRtree(buffered)

    kept = []
    remain_idx = list(range(len(centerlines)))
    for i, pline in enumerate(buffered):
        if i not in remain_idx:
            continue
        remain_idx.remove(i)
        kept.append(centerlines[i])

        for other in tree.query(pline):
            o_idx = index_by_id[id(other)]
            if o_idx not in remain_idx:
                continue
            overlap = other.intersection(pline).area
            combined = other.union(pline).area
            if overlap / combined >= 0.90:
                remain_idx.remove(o_idx)

    return [np.array(line.coords) for line in kept]
def merge_dividers(divider_list):
    """Remove heavily overlapping dividers and join dividers sharing an endpoint.

    Every divider is buffered by 1.  While scanning dividers in order:

    * the current divider is dropped when >= 95% of its buffer lies inside
      another remaining divider's buffer;
    * another divider is dropped when >= 95% of its buffer lies inside the
      current one;
    * otherwise, when exactly one pair of xy endpoints coincides, the two
      point arrays are concatenated (flipping the other one when needed) and
      the other divider is consumed.

    Args:
        divider_list: List of point arrays (``List[np.array(N, 3)]`` per the
            original comment).

    Returns:
        list: Merged point arrays.  The input list itself is returned when it
        holds fewer than two dividers.
    """
    if len(divider_list) < 2:
        return divider_list

    divider_lines = [LineString(divider) for divider in divider_list]
    buffered = [
        line.buffer(1, cap_style=CAP_STYLE.flat, join_style=JOIN_STYLE.mitre)
        for line in divider_lines
    ]
    tree = STRtree(buffered)
    index_by_id = {id(poly): pos for pos, poly in enumerate(buffered)}

    merged = []
    remain_idx = list(range(len(buffered)))
    for i, pline in enumerate(buffered):
        if i not in remain_idx:
            continue
        remain_idx.remove(i)
        merged.append(divider_list[i])

        for other in tree.query(pline):
            o_idx = index_by_id[id(other)]
            if o_idx not in remain_idx:
                continue

            overlap = other.intersection(pline).area
            o_iof = overlap / other.area
            p_iof = overlap / pline.area
            # The current divider is almost fully covered: discard it.
            if p_iof >= 0.95:
                merged.pop()
                break
            # The other divider is almost fully covered: discard that one.
            if o_iof >= 0.95:
                remain_idx.remove(o_idx)
                continue

            other_pts = divider_list[o_idx]
            # Rows 0/1: start/end of the current divider; rows 2/3: start/end
            # of the other divider (xy only).
            endpoints = np.concatenate(
                [merged[-1][[0, -1], :2], other_pts[[0, -1], :2]], axis=0)
            dist_mat = distance.cdist(endpoints, endpoints, 'euclidean')
            # Mask the diagonal so a point never matches itself.
            np.fill_diagonal(dist_mat, 100)
            touching = np.where(dist_mat == 0)[0].tolist()

            if touching == [0, 2]:
                # start of current == start of other: reversed other first.
                joined = np.concatenate(
                    [np.flip(other_pts, axis=0)[:-1], merged[-1]])
            elif touching == [1, 2]:
                # end of current == start of other: current then other.
                joined = np.concatenate([merged[-1][:-1], other_pts])
            elif touching == [0, 3]:
                # start of current == end of other: other then current.
                joined = np.concatenate([other_pts[:-1], merged[-1]])
            elif touching == [1, 3]:
                # end of current == end of other: current then reversed other.
                joined = np.concatenate(
                    [merged[-1][:-1], np.flip(other_pts, axis=0)])
            else:
                joined = None

            if joined is not None:
                merged[-1] = joined
                remain_idx.remove(o_idx)
            elif len(touching) > 2:
                # More than one coinciding endpoint pair: drop the other one.
                remain_idx.remove(o_idx)

    return merged
def extract_local_divider(nearby_dividers, ego_SE3_city, patch_box,
                          patch_angle, patch_size):
    """Clip dividers to the patch, transform them and drop near-duplicates.

    Non-empty dividers are intersected with the patch and each resulting
    piece goes through ``proc_line``.  A divider is then discarded when its
    1-unit buffer has IoU >= 0.90 with the buffer of an earlier kept divider.

    Args:
        nearby_dividers: Iterable of shapely line geometries.
        ego_SE3_city: Transform forwarded to ``proc_line``.
        patch_box: Patch box forwarded to ``get_patch_coord``.
        patch_angle: Patch angle forwarded to ``get_patch_coord``.
        patch_size: Accepted for signature compatibility; not used here.

    Returns:
        list[np.ndarray]: Coordinate arrays of the kept dividers.
    """
    patch = NuScenesMapExplorer.get_patch_coord(patch_box, patch_angle)

    local_lines = []
    for divider in nearby_dividers:
        if divider.is_empty:
            # Skip lines without nodes.
            continue
        inside = divider.intersection(patch)
        if inside.is_empty:
            continue
        if inside.geom_type != 'MultiLineString':
            local_lines.append(proc_line(inside, ego_SE3_city))
            continue
        for part in inside.geoms:
            if part.is_empty:
                continue
            local_lines.append(proc_line(part, ego_SE3_city))

    strips = [
        ln.buffer(1, cap_style=CAP_STYLE.flat, join_style=JOIN_STYLE.mitre)
        for ln in local_lines
    ]
    # The spatial index hands back geometries; recover their list position
    # through object identity.
    position_of = {id(strip): k for k, strip in enumerate(strips)}
    tree = STRtree(strips)

    survivors = []
    pending = list(range(len(local_lines)))
    for k, strip in enumerate(strips):
        if k not in pending:
            continue
        pending.remove(k)
        survivors.append(local_lines[k])

        for hit in tree.query(strip):
            hit_k = position_of[id(hit)]
            if hit_k not in pending:
                continue
            shared_area = hit.intersection(strip).area
            total_area = hit.union(strip).area
            iou = shared_area / total_area
            if iou >= 0.90:
                pending.remove(hit_k)

    return [np.array(ln.coords) for ln in survivors]
def extract_local_boundary(avm, ego_SE3_city, patch_box, patch_angle, patch_size):
    """Extract drivable-area boundary polylines inside the local patch.

    Drivable areas from ``avm.get_scenario_vector_drivable_areas()`` are
    clipped to the patch, passed through ``proc_polygon``, unioned, and the
    exterior / interior rings of the union are clipped to a box that is 0.2
    smaller than the patch on every side.

    Args:
        avm: Map object exposing ``get_scenario_vector_drivable_areas()``.
        ego_SE3_city: Transform forwarded to ``proc_polygon``.
        patch_box: Patch box forwarded to ``get_patch_coord``.
        patch_angle: Patch angle forwarded to ``get_patch_coord``.
        patch_size: ``(height, width)``-style pair; index 1 bounds x and
            index 0 bounds y of the local clipping box.

    Returns:
        list[np.ndarray]: One coordinate array per boundary polyline; empty
        when no drivable area survives clipping.

    Raises:
        ValueError: If clipping a drivable area yields a geometry that is
            neither ``Polygon`` nor ``MultiPolygon``.
        NotImplementedError: If a clipped ring is neither a ``LineString``
            nor a ``MultiLineString``.
    """
    patch = NuScenesMapExplorer.get_patch_coord(patch_box, patch_angle)
    boundary_list = [da.xyz for da in avm.get_scenario_vector_drivable_areas()]

    polygon_list = []
    for exterior_coords in boundary_list:
        polygon = Polygon(exterior_coords, [])
        if not polygon.is_valid:
            continue
        new_polygon = polygon.intersection(patch)
        if new_polygon.is_empty:
            continue

        # Geometry type names are compared by value: the previous ``is``
        # checks only worked through string interning.
        if new_polygon.geom_type == 'Polygon':
            if not new_polygon.is_valid:
                continue
            new_polygon = proc_polygon(new_polygon, ego_SE3_city)
            if not new_polygon.is_valid:
                continue
        elif new_polygon.geom_type == 'MultiPolygon':
            polygons = []
            for single_polygon in new_polygon.geoms:
                if not single_polygon.is_valid or single_polygon.is_empty:
                    continue
                new_single_polygon = proc_polygon(single_polygon, ego_SE3_city)
                if not new_single_polygon.is_valid:
                    continue
                polygons.append(new_single_polygon)
            if len(polygons) == 0:
                continue
            new_polygon = MultiPolygon(polygons)
            if not new_polygon.is_valid:
                continue
        else:
            raise ValueError('{} is not valid'.format(new_polygon.geom_type))

        if new_polygon.geom_type == 'Polygon':
            new_polygon = MultiPolygon([new_polygon])
        polygon_list.append(new_polygon)

    # Nothing to union: there is no boundary inside this patch.
    if not polygon_list:
        return []

    union_segments = ops.unary_union(polygon_list)
    max_x = patch_size[1] / 2
    max_y = patch_size[0] / 2
    # Shrink the box slightly so the patch border itself is not reported.
    local_patch = box(-max_x + 0.2, -max_y + 0.2, max_x - 0.2, max_y - 0.2)

    if union_segments.geom_type != 'MultiPolygon':
        union_segments = MultiPolygon([union_segments])
    exteriors = []
    interiors = []
    for poly in union_segments.geoms:
        exteriors.append(poly.exterior)
        interiors.extend(poly.interiors)

    results = []
    for ext in exteriors:
        # Exterior rings are forced to clockwise orientation.
        # NOTE(review): assigning ``coords`` requires mutable geometries
        # (Shapely 1.x behaviour) -- confirm the pinned Shapely version.
        if ext.is_ccw:
            ext.coords = list(ext.coords)[::-1]
        lines = ext.intersection(local_patch)
        if isinstance(lines, MultiLineString):
            lines = ops.linemerge(lines)
        results.append(lines)

    for inter in interiors:
        # Interior rings are forced to counter-clockwise orientation.
        if not inter.is_ccw:
            inter.coords = list(inter.coords)[::-1]
        lines = inter.intersection(local_patch)
        if isinstance(lines, MultiLineString):
            lines = ops.linemerge(lines)
        results.append(lines)

    boundary_lines = []
    for line in results:
        if line.is_empty:
            continue
        if line.geom_type == 'MultiLineString':
            boundary_lines.extend(
                np.array(single_line.coords) for single_line in line.geoms)
        elif line.geom_type == 'LineString':
            boundary_lines.append(np.array(line.coords))
        else:
            raise NotImplementedError
    return boundary_lines
def extract_local_ped_crossing(avm, ego_SE3_city, patch_box, patch_angle, patch_size):
    """Extract pedestrian-crossing outlines inside the local patch.

    Crossings from ``avm.get_scenario_ped_crossings()`` are clipped to the
    patch and passed through ``proc_polygon``.  Crossings returned by the
    spatial index whose longest rotated-rectangle edges are nearly parallel
    are unioned.  The exterior of every resulting polygon is clipped to a box
    0.2 smaller than the patch on every side.

    Args:
        avm: Map object exposing ``get_scenario_ped_crossings()``.
        ego_SE3_city: Transform forwarded to ``proc_polygon``.
        patch_box: Patch box forwarded to ``get_patch_coord``.
        patch_angle: Patch angle forwarded to ``get_patch_coord``.
        patch_size: ``(height, width)``-style pair; index 1 bounds x and
            index 0 bounds y of the local clipping box.

    Returns:
        list[np.ndarray]: One coordinate array per crossing outline.

    Raises:
        ValueError: If clipping a crossing yields a geometry that is neither
            ``Polygon`` nor ``MultiPolygon``.
    """
    ped_list = [pc.polygon for pc in avm.get_scenario_ped_crossings()]
    patch = NuScenesMapExplorer.get_patch_coord(patch_box, patch_angle)

    polygon_list = []
    for exterior_coords in ped_list:
        polygon = Polygon(exterior_coords, [])
        if not polygon.is_valid:
            continue
        new_polygon = polygon.intersection(patch)
        if new_polygon.is_empty:
            continue

        # Geometry type names are compared by value: the previous ``is``
        # checks only worked through string interning.
        if new_polygon.geom_type == 'Polygon':
            if not new_polygon.is_valid:
                continue
            new_polygon = proc_polygon(new_polygon, ego_SE3_city)
            if not new_polygon.is_valid:
                continue
        elif new_polygon.geom_type == 'MultiPolygon':
            polygons = []
            for single_polygon in new_polygon.geoms:
                if not single_polygon.is_valid or single_polygon.is_empty:
                    continue
                new_single_polygon = proc_polygon(single_polygon, ego_SE3_city)
                if not new_single_polygon.is_valid:
                    continue
                polygons.append(new_single_polygon)
            if len(polygons) == 0:
                continue
            new_polygon = MultiPolygon(polygons)
            if not new_polygon.is_valid:
                continue
        else:
            raise ValueError('{} is not valid'.format(new_polygon.geom_type))

        if new_polygon.geom_type == 'Polygon':
            new_polygon = MultiPolygon([new_polygon])
        polygon_list.append(new_polygon)

    def get_rec_direction(geom):
        """Return the longest edge vector of geom's rotated rectangle and its length."""
        rect = geom.minimum_rotated_rectangle  # polygon as rotated rect
        rect_v_p = np.array(rect.exterior.coords)[:3]  # three corner points
        rect_v = rect_v_p[1:] - rect_v_p[:-1]  # two adjacent edge vectors
        v_len = np.linalg.norm(rect_v, axis=-1)  # edge lengths
        longest_v_i = v_len.argmax()
        return rect_v[longest_v_i], v_len[longest_v_i]

    ped_geoms = polygon_list
    tree = STRtree(ped_geoms)
    # tree.query returns geometry objects here; map them back by identity.
    index_by_id = {id(geom): pos for pos, geom in enumerate(ped_geoms)}

    final_pgeom = []
    remain_idx = list(range(len(ped_geoms)))
    for i, pgeom in enumerate(ped_geoms):
        if i not in remain_idx:
            continue
        remain_idx.remove(i)
        pgeom_v, pgeom_v_norm = get_rec_direction(pgeom)
        final_pgeom.append(pgeom)

        for o in tree.query(pgeom):
            o_idx = index_by_id[id(o)]
            if o_idx not in remain_idx:
                continue
            o_v, o_v_norm = get_rec_direction(o)
            cos = pgeom_v.dot(o_v) / (pgeom_v_norm * o_v_norm)
            if 1 - np.abs(cos) < 0.01:  # |cos| > 0.99, i.e. about 8 degrees
                # Union nearly parallel crossings into one instance.
                final_pgeom[-1] = final_pgeom[-1].union(o)
                remain_idx.remove(o_idx)

    for i, geom in enumerate(final_pgeom):
        if geom.geom_type != 'MultiPolygon':
            final_pgeom[i] = MultiPolygon([geom])

    max_x = patch_size[1] / 2
    max_y = patch_size[0] / 2
    local_patch = box(-max_x + 0.2, -max_y + 0.2, max_x - 0.2, max_y - 0.2)

    results = []
    for geom in final_pgeom:
        for ped_poly in geom.geoms:
            ext = ped_poly.exterior
            # Exterior rings are forced to counter-clockwise orientation.
            # NOTE(review): assigning ``coords`` requires mutable geometries
            # (Shapely 1.x behaviour) -- confirm the pinned Shapely version.
            if not ext.is_ccw:
                ext.coords = list(ext.coords)[::-1]
            lines = ext.intersection(local_patch)

            # ``geom_type`` replaces the deprecated ``.type`` alias and matches
            # the attribute used everywhere else in this function.
            if lines.geom_type != 'LineString':
                lines = ops.linemerge(lines)
                # Same instance but not connected: concatenate the pieces.
                if lines.geom_type != 'LineString':
                    pieces = [np.array(part.coords) for part in lines.geoms]
                    lines = LineString(np.concatenate(pieces, axis=0))

            results.append(np.array(lines.coords))
    return results
if __name__ == '__main__':
    # Build the AV2 info files for every split, reading from and writing to
    # the same data root.
    args = parse_args()
    for name in ('train', 'val', 'test'):
        create_av2_infos_mp(
            root_path=args.data_root,
            split=name,
            info_prefix='av2',
            dest_path=args.data_root,
            pc_range=args.pc_range,
        )
\ No newline at end of file
docker-hub/MapTRv2/MapTR/tools/maptrv2/custom_nusc_map_converter.py
0 → 100644
View file @
19472568
import
argparse
from
os
import
path
as
osp
import
sys
import
mmcv
import
numpy
as
np
import
os
from
collections
import
OrderedDict
from
nuscenes.nuscenes
import
NuScenes
from
nuscenes.utils.geometry_utils
import
view_points
from
os
import
path
as
osp
# from pyquaternion import Quaternion
from
shapely.geometry
import
MultiPoint
,
box
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
from
mmdet3d.core.bbox.box_np_ops
import
points_cam2img
from
mmdet3d.datasets
import
NuScenesDataset
from
nuscenes.map_expansion.map_api
import
NuScenesMap
,
NuScenesMapExplorer
from
nuscenes.eval.common.utils
import
quaternion_yaw
,
Quaternion
from
nuscenes.map_expansion.bitmap
import
BitMap
from
matplotlib.patches
import
Polygon
as
mPolygon
from
shapely
import
affinity
,
ops
# from shapely.geometry import LineString, box, MultiPolygon, MultiLineString
from
shapely.geometry
import
Polygon
,
MultiPolygon
,
LineString
,
Point
,
box
,
MultiLineString
from
matplotlib.axes
import
Axes
from
matplotlib.figure
import
Figure
import
networkx
as
nx
sys
.
path
.
append
(
'.'
)
class CNuScenesMapExplorer(NuScenesMapExplorer):
    """``NuScenesMapExplorer`` extended with lane-centerline extraction."""

    def __init__(self, *args, **kwargs):
        # Previously spelled ``__ini__`` with ``super(self, cls)`` arguments
        # reversed, so it never ran; construction still simply defers to the
        # parent initializer.
        super().__init__(*args, **kwargs)

    def _get_centerline(self,
                        patch_box: Tuple[float, float, float, float],
                        patch_angle: float,
                        layer_name: str,
                        return_token: bool = False) -> dict:
        """
        Retrieve the centerline of a particular layer within the specified patch.
        :param patch_box: Patch box defined as [x_center, y_center, height, width].
        :param patch_angle: Patch orientation in degrees.
        :param layer_name: name of map layer to be extracted
            ('lane' or 'lane_connector').
        :param return_token: Accepted for signature compatibility; not used
            by this method.
        :return: dict(token:record_dict, token:record_dict,...) where each
            record_dict holds 'centerline', 'token', 'incoming_tokens' and
            'outgoing_tokens'.
        :raises ValueError: if layer_name is not a centerline layer.
        """
        if layer_name not in ['lane', 'lane_connector']:
            raise ValueError('{} is not a centerline layer'.format(layer_name))

        patch_x = patch_box[0]
        patch_y = patch_box[1]
        patch = self.get_patch_coord(patch_box, patch_angle)
        records = getattr(self.map_api, layer_name)

        centerline_dict = dict()
        for record in records:
            if record['polygon_token'] is None:
                continue
            polygon = self.map_api.extract_polygon(record['polygon_token'])
            if not polygon.is_valid:
                continue
            # Only lanes whose footprint touches the patch are considered.
            if polygon.intersection(patch).is_empty:
                continue

            token = record['token']
            # discretize_lanes takes a list of tokens; the earlier extra call
            # that passed the whole record dict was discarded and is removed.
            centerline = list(
                self.map_api.discretize_lanes([token], 0.5).values())[0]
            # A line needs at least two points; skip lanes without a usable
            # discretized path instead of failing on the slicing below.
            if len(centerline) < 2:
                continue
            centerline = LineString(np.array(centerline)[:, :2].round(3))
            if centerline.is_empty:
                continue
            centerline = centerline.intersection(patch)
            if centerline.is_empty:
                continue

            centerline = to_patch_coord(centerline, patch_angle, patch_x, patch_y)
            centerline_dict[token] = dict(
                centerline=centerline,
                token=token,
                incoming_tokens=self.map_api.get_incoming_lane_ids(token),
                outgoing_tokens=self.map_api.get_outgoing_lane_ids(token),
            )
        return centerline_dict
def to_patch_coord(new_polygon, patch_angle, patch_x, patch_y):
    """Express a geometry in patch-local coordinates.

    The geometry is rotated by ``-patch_angle`` degrees about
    ``(patch_x, patch_y)`` and then translated by ``(-patch_x, -patch_y)``.

    Args:
        new_polygon: Shapely geometry to transform.
        patch_angle: Patch orientation in degrees.
        patch_x: X coordinate of the patch centre.
        patch_y: Y coordinate of the patch centre.

    Returns:
        The transformed shapely geometry.
    """
    rotated = affinity.rotate(
        new_polygon,
        -patch_angle,
        origin=(patch_x, patch_y),
        use_radians=False,
    )
    # Affine coefficients [a, b, d, e, xoff, yoff]: identity plus translation.
    shift_to_origin = [1.0, 0.0, 0.0, 1.0, -patch_x, -patch_y]
    return affinity.affine_transform(rotated, shift_to_origin)
def get_available_scenes(nusc):
    """Get available scenes from the input nuscenes class.

    For every scene, the LIDAR_TOP sample data of the first sample is looked
    up and its path is checked with ``mmcv.is_filepath``; scenes passing the
    check are returned.

    Args:
        nusc (class): Dataset class in the nuScenes dataset.

    Returns:
        available_scenes (list[dict]): List of basic information for the
            available scenes.
    """
    print('total scene num: {}'.format(len(nusc.scene)))

    available_scenes = []
    for scene in nusc.scene:
        scene_rec = nusc.get('scene', scene['token'])
        first_sample = nusc.get('sample', scene_rec['first_sample_token'])
        sd_rec = nusc.get('sample_data', first_sample['data']['LIDAR_TOP'])

        # Only the first frame of the scene is inspected.
        lidar_path, _boxes, _ = nusc.get_sample_data(sd_rec['token'])
        lidar_path = str(lidar_path)
        cwd = os.getcwd()
        if cwd in lidar_path:
            # path from lyftdataset is absolute path; make it relative
            lidar_path = lidar_path.split(f'{cwd}/')[-1]

        # NOTE(review): mmcv.is_filepath presumably validates the type of the
        # argument rather than file existence -- confirm the intended check.
        if mmcv.is_filepath(lidar_path):
            available_scenes.append(scene)

    print('exist scene num: {}'.format(len(available_scenes)))
    return available_scenes
def
_get_can_bus_info
(
nusc
,
nusc_can_bus
,
sample
):
scene_name
=
nusc
.
get
(
'scene'
,
sample
[
'scene_token'
])[
'name'
]
sample_timestamp
=
sample
[
'timestamp'
]
try
:
pose_list
=
nusc_can_bus
.
get_messages
(
scene_name
,
'pose'
)
except
:
return
np
.
zeros
(
18
)
# server scenes do not have can bus information.
can_bus
=
[]
# during each scene, the first timestamp of can_bus may be large than the first sample's timestamp
last_pose
=
pose_list
[
0
]
for
i
,
pose
in
enumerate
(
pose_list
):
if
pose
[
'utime'
]
>
sample_timestamp
:
break
last_pose
=
pose
_
=
last_pose
.
pop
(
'utime'
)
# useless
pos
=
last_pose
.
pop
(
'pos'
)
rotation
=
last_pose
.
pop
(
'orientation'
)
can_bus
.
extend
(
pos
)
can_bus
.
extend
(
rotation
)
for
key
in
last_pose
.
keys
():
can_bus
.
extend
(
pose
[
key
])
# 16 elements
can_bus
.
extend
([
0.
,
0.
])
return
np
.
array
(
can_bus
)
def obtain_sensor2top(nusc,
                      sensor_token,
                      l2e_t,
                      l2e_r_mat,
                      e2g_t,
                      e2g_r_mat,
                      sensor_type='lidar'):
    """Obtain the info with RT matric from general sensor to Top LiDAR.

    Args:
        nusc (class): Dataset class in the nuScenes dataset.
        sensor_token (str): Sample data token corresponding to the
            specific sensor type.
        l2e_t (np.ndarray): Translation from lidar to ego in shape (1, 3).
        l2e_r_mat (np.ndarray): Rotation matrix from lidar to ego
            in shape (3, 3).
        e2g_t (np.ndarray): Translation from ego to global in shape (1, 3).
        e2g_r_mat (np.ndarray): Rotation matrix from ego to global
            in shape (3, 3).
        sensor_type (str): Sensor to calibrate. Default: 'lidar'.

    Returns:
        sweep (dict): Sweep information after transformation.
    """
    sd_rec = nusc.get('sample_data', sensor_token)
    cs_record = nusc.get('calibrated_sensor',
                         sd_rec['calibrated_sensor_token'])
    pose_record = nusc.get('ego_pose', sd_rec['ego_pose_token'])

    data_path = str(nusc.get_sample_data_path(sd_rec['token']))
    cwd = os.getcwd()
    if cwd in data_path:
        # path from lyftdataset is absolute path; make it relative
        data_path = data_path.split(f'{cwd}/')[-1]

    sweep = {
        'data_path': data_path,
        'type': sensor_type,
        'sample_data_token': sd_rec['token'],
        'sensor2ego_translation': cs_record['translation'],
        'sensor2ego_rotation': cs_record['rotation'],
        'ego2global_translation': pose_record['translation'],
        'ego2global_rotation': pose_record['rotation'],
        'timestamp': sd_rec['timestamp'],
    }

    sensor2ego_t = sweep['sensor2ego_translation']
    sensor_e2g_t = sweep['ego2global_translation']
    sensor2ego_rot = Quaternion(sweep['sensor2ego_rotation']).rotation_matrix
    sensor_e2g_rot = Quaternion(sweep['ego2global_rotation']).rotation_matrix

    # obtain the RT from sensor to Top LiDAR
    # sweep->ego->global->ego'->lidar
    inv_l2e_rot_t = np.linalg.inv(l2e_r_mat).T
    global2lidar = np.linalg.inv(e2g_r_mat).T @ inv_l2e_rot_t

    R = (sensor2ego_rot.T @ sensor_e2g_rot.T) @ global2lidar
    T = (sensor2ego_t @ sensor_e2g_rot.T + sensor_e2g_t) @ global2lidar
    T -= e2g_t @ global2lidar + l2e_t @ inv_l2e_rot_t

    sweep['sensor2lidar_rotation'] = R.T  # points @ R.T + T
    sweep['sensor2lidar_translation'] = T
    return sweep
def _fill_trainval_infos(nusc,
                         nusc_can_bus,
                         nusc_maps,
                         map_explorer,
                         train_scenes,
                         val_scenes,
                         test=False,
                         max_sweeps=10,
                         point_cloud_range=[-15.0, -30.0, -10.0, 15.0, 30.0, 10.0]):
    """Generate the train/val infos from the raw data.

    Args:
        nusc (:obj:`NuScenes`): Dataset class in the nuScenes dataset.
        nusc_can_bus: CAN bus API forwarded to ``_get_can_bus_info``.
        nusc_maps: Mapping indexed by map location, forwarded to
            ``obtain_vectormap``.
        map_explorer: Mapping indexed by map location, forwarded to
            ``obtain_vectormap``.
        train_scenes (list[str]): Basic information of training scenes.
        val_scenes (list[str]): Basic information of validation scenes.
            Not read here: every sample outside ``train_scenes`` goes to the
            validation list.
        test (bool): Whether use the test mode. Not read by this function.
            Default: False.
        max_sweeps (int): Max number of sweeps. Default: 10.
        point_cloud_range (list[float]): Range forwarded to
            ``obtain_vectormap``.

    Returns:
        tuple[list[dict]]: Information of training set and validation set
            that will be saved to the info file.
    """
    # obtain 6 image's information per frame
    camera_types = (
        'CAM_FRONT',
        'CAM_FRONT_RIGHT',
        'CAM_FRONT_LEFT',
        'CAM_BACK',
        'CAM_BACK_LEFT',
        'CAM_BACK_RIGHT',
    )

    train_nusc_infos, val_nusc_infos = [], []
    frame_idx = 0
    for sample in mmcv.track_iter_progress(nusc.sample):
        scene_rec = nusc.get('scene', sample['scene_token'])
        map_location = nusc.get('log', scene_rec['log_token'])['location']

        lidar_token = sample['data']['LIDAR_TOP']
        sd_rec = nusc.get('sample_data', lidar_token)
        cs_record = nusc.get('calibrated_sensor',
                             sd_rec['calibrated_sensor_token'])
        pose_record = nusc.get('ego_pose', sd_rec['ego_pose_token'])
        lidar_path, _boxes, _ = nusc.get_sample_data(lidar_token)

        mmcv.check_file_exist(lidar_path)
        can_bus = _get_can_bus_info(nusc, nusc_can_bus, sample)

        info = {
            'lidar_path': lidar_path,
            'token': sample['token'],
            'prev': sample['prev'],
            'next': sample['next'],
            'can_bus': can_bus,
            'frame_idx': frame_idx,  # temporal related info
            'sweeps': [],
            'cams': {},
            'map_location': map_location,
            'scene_token': sample['scene_token'],  # temporal related info
            'lidar2ego_translation': cs_record['translation'],
            'lidar2ego_rotation': cs_record['rotation'],
            'ego2global_translation': pose_record['translation'],
            'ego2global_rotation': pose_record['rotation'],
            'timestamp': sample['timestamp'],
        }

        # The frame counter restarts after the last sample of a scene.
        frame_idx = 0 if sample['next'] == '' else frame_idx + 1

        l2e_t = info['lidar2ego_translation']
        e2g_t = info['ego2global_translation']
        l2e_r_mat = Quaternion(info['lidar2ego_rotation']).rotation_matrix
        e2g_r_mat = Quaternion(info['ego2global_rotation']).rotation_matrix

        for cam in camera_types:
            cam_token = sample['data'][cam]
            _cam_path, _, cam_intrinsic = nusc.get_sample_data(cam_token)
            cam_info = obtain_sensor2top(nusc, cam_token, l2e_t, l2e_r_mat,
                                         e2g_t, e2g_r_mat, cam)
            cam_info['cam_intrinsic'] = cam_intrinsic
            info['cams'][cam] = cam_info

        # obtain sweeps for a single key-frame by walking the 'prev' chain
        sd_rec = nusc.get('sample_data', lidar_token)
        sweeps = []
        while len(sweeps) < max_sweeps and sd_rec['prev'] != '':
            sweeps.append(
                obtain_sensor2top(nusc, sd_rec['prev'], l2e_t, l2e_r_mat,
                                  e2g_t, e2g_r_mat, 'lidar'))
            sd_rec = nusc.get('sample_data', sd_rec['prev'])
        info['sweeps'] = sweeps

        # obtain annotation
        info = obtain_vectormap(nusc_maps, map_explorer, info,
                                point_cloud_range)

        target = (train_nusc_infos
                  if sample['scene_token'] in train_scenes
                  else val_nusc_infos)
        target.append(info)

    return train_nusc_infos, val_nusc_infos
def obtain_vectormap(nusc_maps, map_explorer, info, point_cloud_range):
    """Attach vectorized map annotations to ``info``.

    The lidar-to-global pose is composed from the lidar2ego and ego2global
    entries of ``info`` and passed to
    ``VectorizedLocalMap.gen_vectorized_samples``; the result is stored under
    ``info["annotation"]``.

    Args:
        nusc_maps: Mapping from map location to the map object.
        map_explorer: Mapping from map location to the map explorer.
        info (dict): Sample info holding the rotation / translation entries
            and ``'map_location'``.  Modified in place.
        point_cloud_range: Six-element range; indices (1, 4) give the patch
            height and (0, 3) the patch width.

    Returns:
        dict: The same ``info`` dict with ``"annotation"`` added.
    """
    def _homogeneous(rotation, translation):
        # 4x4 transform from a quaternion and a translation vector.
        mat = np.eye(4)
        mat[:3, :3] = Quaternion(rotation).rotation_matrix
        mat[:3, 3] = translation
        return mat

    lidar2ego = _homogeneous(info['lidar2ego_rotation'],
                             info['lidar2ego_translation'])
    ego2global = _homogeneous(info['ego2global_rotation'],
                              info['ego2global_translation'])
    lidar2global = ego2global @ lidar2ego

    lidar2global_translation = list(lidar2global[:3, 3])
    lidar2global_rotation = list(Quaternion(matrix=lidar2global).q)

    location = info['map_location']
    patch_size = (
        point_cloud_range[4] - point_cloud_range[1],  # patch height
        point_cloud_range[3] - point_cloud_range[0],  # patch width
    )

    vector_map = VectorizedLocalMap(nusc_maps[location],
                                    map_explorer[location],
                                    patch_size)
    info["annotation"] = vector_map.gen_vectorized_samples(
        lidar2global_translation, lidar2global_rotation)
    return info
class VectorizedLocalMap(object):
    """Extract vectorized map elements around a pose from a nuScenes map.

    For a patch centred on a given global pose the class collects lane
    dividers, pedestrian crossings, road boundaries and lane centerlines and
    returns them as arrays of 2D points expressed in the patch-local frame
    (translated to the patch centre and rotated by ``-patch_angle``).

    Args:
        nusc_map: Map object for one location; only stored on the instance.
        map_explorer: Explorer used for all geometry queries. It must provide
            ``map_api``, ``get_patch_coord`` and ``_get_centerline``.
        patch_size (tuple): ``(height, width)`` of the patch; index 0 is
            halved for the y extent and index 1 for the x extent.
        map_classes (list[str] | None): Output classes to generate. Defaults
            to ``['divider', 'ped_crossing', 'boundary', 'centerline']``.
        line_classes (list[str] | None): Line layers merged into 'divider'.
            Defaults to ``['road_divider', 'lane_divider']``.
        ped_crossing_classes (list[str] | None): Layers used for
            'ped_crossing'. Defaults to ``['ped_crossing']``.
        contour_classes (list[str] | None): Polygon layers used for
            'boundary'. Defaults to ``['road_segment', 'lane']``.
        centerline_classes (list[str] | None): Layers used for 'centerline'.
            Defaults to ``['lane_connector', 'lane']``.
        use_simplify (bool): Accepted for interface compatibility; it is not
            read anywhere in this class.
    """

    CLASS2LABEL = {
        'road_divider': 0,
        'lane_divider': 0,
        'ped_crossing': 1,
        'contours': 2,
        'others': -1
    }

    def __init__(self,
                 nusc_map,
                 map_explorer,
                 patch_size,
                 map_classes=None,
                 line_classes=None,
                 ped_crossing_classes=None,
                 contour_classes=None,
                 centerline_classes=None,
                 use_simplify=True,
                 ):
        super().__init__()
        self.nusc_map = nusc_map
        self.map_explorer = map_explorer
        # Defaults are created per instance so that mutating one instance's
        # class lists can never leak into another instance. Lists supplied by
        # the caller are stored as-is (no copy), exactly as before.
        self.vec_classes = (
            ['divider', 'ped_crossing', 'boundary', 'centerline']
            if map_classes is None else map_classes)
        self.line_classes = (
            ['road_divider', 'lane_divider']
            if line_classes is None else line_classes)
        self.ped_crossing_classes = (
            ['ped_crossing']
            if ped_crossing_classes is None else ped_crossing_classes)
        self.polygon_classes = (
            ['road_segment', 'lane']
            if contour_classes is None else contour_classes)
        self.centerline_classes = (
            ['lane_connector', 'lane']
            if centerline_classes is None else centerline_classes)
        self.patch_size = patch_size

    def gen_vectorized_samples(self, lidar2global_translation, lidar2global_rotation):
        """Build the ground-truth map layers around a lidar pose.

        Args:
            lidar2global_translation: Sequence whose first two entries are
                the global x/y position used as the patch centre.
            lidar2global_rotation: Quaternion coefficients passed to
                ``Quaternion``; only its yaw is used to orient the patch.

        Returns:
            dict: Keys 'divider', 'ped_crossing', 'boundary' and
            'centerline', each mapping to a list of ``np.ndarray`` point
            sequences (one array per line instance).

        Raises:
            ValueError: If ``self.vec_classes`` holds an unknown class name.
        """
        map_pose = lidar2global_translation[:2]
        rotation = Quaternion(lidar2global_rotation)

        patch_box = (map_pose[0], map_pose[1],
                     self.patch_size[0], self.patch_size[1])
        # Yaw converted from radians to degrees.
        patch_angle = quaternion_yaw(rotation) / np.pi * 180

        map_dict = {'divider': [], 'ped_crossing': [], 'boundary': [], 'centerline': []}
        for vec_class in self.vec_classes:
            if vec_class == 'divider':
                line_geom = self.get_map_geom(patch_box, patch_angle, self.line_classes)
                instances = []
                for one_type_instances in self.line_geoms_to_instances(line_geom).values():
                    instances.extend(one_type_instances)
            elif vec_class == 'ped_crossing':
                ped_geom = self.get_map_geom(patch_box, patch_angle, self.ped_crossing_classes)
                instances = self.ped_poly_geoms_to_instances(ped_geom)
            elif vec_class == 'boundary':
                polygon_geom = self.get_map_geom(patch_box, patch_angle, self.polygon_classes)
                instances = self.poly_geoms_to_instances(polygon_geom)
            elif vec_class == 'centerline':
                centerline_geom = self.get_centerline_geom(patch_box, patch_angle, self.centerline_classes)
                instances = self.centerline_geoms_to_instances(centerline_geom)
            else:
                raise ValueError(f'WRONG vec_class: {vec_class}')
            for instance in instances:
                map_dict[vec_class].append(np.array(instance.coords))
        return map_dict

    def get_centerline_geom(self, patch_box, patch_angle, layer_names):
        """Collect centerline records for the requested layers.

        Layers that are not listed in ``self.centerline_classes`` are
        ignored. The per-layer dicts returned by
        ``map_explorer._get_centerline`` are merged into one dict.
        """
        map_geom = {}
        for layer_name in layer_names:
            if layer_name not in self.centerline_classes:
                continue
            layer_centerline_dict = self.map_explorer._get_centerline(
                patch_box, patch_angle, layer_name, return_token=False)
            if not layer_centerline_dict:
                continue
            map_geom.update(layer_centerline_dict)
        return map_geom

    def get_map_geom(self, patch_box, patch_angle, layer_names):
        """Return ``{layer_name: [geometry, ...]}`` for the requested layers.

        Each layer is dispatched by membership in the line, polygon or
        pedestrian-crossing class lists (checked in that order); layers
        belonging to none of them are skipped.
        """
        map_geom = {}
        for layer_name in layer_names:
            if layer_name in self.line_classes:
                map_geom[layer_name] = self.get_divider_line(patch_box, patch_angle, layer_name)
            elif layer_name in self.polygon_classes:
                map_geom[layer_name] = self.get_contour_line(patch_box, patch_angle, layer_name)
            elif layer_name in self.ped_crossing_classes:
                map_geom[layer_name] = self.get_ped_crossing_line(patch_box, patch_angle)
        return map_geom

    @staticmethod
    def _to_patch_frame(geom, patch_box, patch_angle):
        """Rotate ``geom`` by ``-patch_angle`` degrees about the patch centre and shift the centre to the origin."""
        patch_x, patch_y = patch_box[0], patch_box[1]
        geom = affinity.rotate(geom, -patch_angle,
                               origin=(patch_x, patch_y), use_radians=False)
        return affinity.affine_transform(
            geom, [1.0, 0.0, 0.0, 1.0, -patch_x, -patch_y])

    def _clip_polygon_to_patch(self, polygon, patch, patch_box, patch_angle):
        """Clip ``polygon`` to ``patch`` and return it in the patch frame as a MultiPolygon, or None if nothing remains."""
        new_polygon = polygon.intersection(patch)
        if new_polygon.is_empty:
            return None
        new_polygon = self._to_patch_frame(new_polygon, patch_box, patch_angle)
        if new_polygon.geom_type == 'Polygon':
            new_polygon = MultiPolygon([new_polygon])
        return new_polygon

    def get_divider_line(self, patch_box, patch_angle, layer_name):
        """Return the lines of ``layer_name`` clipped to the patch, in the patch frame.

        Raises:
            ValueError: If ``layer_name`` is not one of the map API's
                ``non_geometric_line_layers``.

        Returns:
            list | None: Clipped line geometries; ``None`` for the
            'traffic_light' layer.
        """
        map_api = self.map_explorer.map_api
        if layer_name not in map_api.non_geometric_line_layers:
            raise ValueError("{} is not a line layer".format(layer_name))

        if layer_name == 'traffic_light':
            return None

        patch = self.map_explorer.get_patch_coord(patch_box, patch_angle)

        line_list = []
        for record in getattr(map_api, layer_name):
            line = map_api.extract_line(record['line_token'])
            if line.is_empty:  # Skip lines without nodes.
                continue
            new_line = line.intersection(patch)
            if not new_line.is_empty:
                line_list.append(
                    self._to_patch_frame(new_line, patch_box, patch_angle))
        return line_list

    def get_contour_line(self, patch_box, patch_angle, layer_name):
        """Return the polygons of ``layer_name`` clipped to the patch, in the patch frame.

        'drivable_area' records carry several polygon tokens and are clipped
        without a validity check; every other layer has a single polygon per
        record and invalid polygons are skipped.

        Raises:
            ValueError: If ``layer_name`` is not one of the map API's
                ``non_geometric_polygon_layers``.

        Returns:
            list: One ``MultiPolygon`` per non-empty clipped polygon.
        """
        map_api = self.map_explorer.map_api
        if layer_name not in map_api.non_geometric_polygon_layers:
            raise ValueError('{} is not a polygonal layer'.format(layer_name))

        patch = self.map_explorer.get_patch_coord(patch_box, patch_angle)
        records = getattr(map_api, layer_name)

        polygon_list = []
        if layer_name == 'drivable_area':
            for record in records:
                polygons = [map_api.extract_polygon(polygon_token)
                            for polygon_token in record['polygon_tokens']]
                for polygon in polygons:
                    new_polygon = self._clip_polygon_to_patch(
                        polygon, patch, patch_box, patch_angle)
                    if new_polygon is not None:
                        polygon_list.append(new_polygon)
        else:
            for record in records:
                polygon = map_api.extract_polygon(record['polygon_token'])
                if polygon.is_valid:
                    new_polygon = self._clip_polygon_to_patch(
                        polygon, patch, patch_box, patch_angle)
                    if new_polygon is not None:
                        polygon_list.append(new_polygon)
        return polygon_list

    def get_ped_crossing_line(self, patch_box, patch_angle):
        """Return the valid 'ped_crossing' polygons clipped to the patch, in the patch frame."""
        map_api = self.map_explorer.map_api
        patch = self.map_explorer.get_patch_coord(patch_box, patch_angle)

        polygon_list = []
        for record in getattr(map_api, 'ped_crossing'):
            polygon = map_api.extract_polygon(record['polygon_token'])
            if polygon.is_valid:
                new_polygon = self._clip_polygon_to_patch(
                    polygon, patch, patch_box, patch_angle)
                if new_polygon is not None:
                    polygon_list.append(new_polygon)
        return polygon_list

    def line_geoms_to_instances(self, line_geom):
        """Map each line type in ``line_geom`` to its flat list of LineString instances."""
        line_instances_dict = dict()
        for line_type, a_type_of_lines in line_geom.items():
            line_instances_dict[line_type] = \
                self._one_type_line_geom_to_instances(a_type_of_lines)
        return line_instances_dict

    def _one_type_line_geom_to_instances(self, line_geom):
        """Flatten line geometries into a list of single line instances.

        Empty geometries are dropped and a 'MultiLineString' contributes each
        of its parts.

        Raises:
            NotImplementedError: For a non-empty geometry that is neither a
                'LineString' nor a 'MultiLineString'.
        """
        line_instances = []
        for line in line_geom:
            if line.is_empty:
                continue
            if line.geom_type == 'MultiLineString':
                line_instances.extend(line.geoms)
            elif line.geom_type == 'LineString':
                line_instances.append(line)
            else:
                raise NotImplementedError
        return line_instances

    def _polygon_boundaries_to_lines(self, union_segments, local_patch):
        """Turn the rings of a merged polygon geometry into line instances clipped to ``local_patch``.

        All exterior rings are processed first (oriented clockwise), then all
        interior rings (oriented counter-clockwise).
        """
        if union_segments.is_empty:
            # Nothing inside the patch: there is no boundary to extract.
            return []
        if union_segments.geom_type != 'MultiPolygon':
            union_segments = MultiPolygon([union_segments])

        exteriors = []
        interiors = []
        for poly in union_segments.geoms:
            exteriors.append(poly.exterior)
            interiors.extend(poly.interiors)

        oriented_rings = [(ring, False) for ring in exteriors]
        oriented_rings += [(ring, True) for ring in interiors]

        results = []
        for ring, want_ccw in oriented_rings:
            if ring.is_ccw != want_ccw:
                # Build a reversed ring instead of assigning to ``ring.coords``:
                # Shapely 2.x geometries are immutable and reject the
                # in-place assignment.
                ring = type(ring)(list(ring.coords)[::-1])
            lines = ring.intersection(local_patch)
            if isinstance(lines, MultiLineString):
                lines = ops.linemerge(lines)
            results.append(lines)
        return self._one_type_line_geom_to_instances(results)

    def ped_poly_geoms_to_instances(self, ped_geom):
        """Convert pedestrian-crossing polygons into boundary line instances.

        The polygons under the 'ped_crossing' key are merged, and their rings
        are clipped to the local patch enlarged by 0.2 on every side.
        """
        ped = ped_geom['ped_crossing']
        union_segments = ops.unary_union(ped)
        max_x = self.patch_size[1] / 2
        max_y = self.patch_size[0] / 2
        local_patch = box(-max_x - 0.2, -max_y - 0.2, max_x + 0.2, max_y + 0.2)
        return self._polygon_boundaries_to_lines(union_segments, local_patch)

    def poly_geoms_to_instances(self, polygon_geom):
        """Convert road-segment and lane polygons into road-boundary line instances.

        The 'road_segment' and 'lane' polygons are merged into one geometry,
        and its rings are clipped to the local patch shrunk by 0.2 on every
        side.
        """
        roads = polygon_geom['road_segment']
        lanes = polygon_geom['lane']
        union_roads = ops.unary_union(roads)
        union_lanes = ops.unary_union(lanes)
        union_segments = ops.unary_union([union_roads, union_lanes])
        max_x = self.patch_size[1] / 2
        max_y = self.patch_size[0] / 2
        local_patch = box(-max_x + 0.2, -max_y + 0.2, max_x - 0.2, max_y - 0.2)
        return self._polygon_boundaries_to_lines(union_segments, local_patch)

    def centerline_geoms_to_instances(self, geoms_dict):
        """Merge connected centerlines and return them as line instances."""
        centerline_geoms_list, _ = self.union_centerline(geoms_dict)
        return self._one_type_line_geom_to_instances(centerline_geoms_list)

    def centerline_geoms2vec(self, centerline_geoms_list):
        """Wrap centerline vectors in a ``{'centerline': ('centerline', vectors)}`` dict.

        NOTE(review): this calls ``self._geom_to_vectors``, which is not
        defined in this class; unless a subclass provides it, calling this
        method raises ``AttributeError``. It is not used by the other methods
        of this class.
        """
        vector_dict = {}
        vectors = self._geom_to_vectors(centerline_geoms_list)
        vector_dict.update({'centerline': ('centerline', vectors)})
        return vector_dict

    @staticmethod
    def _rounded_points(line):
        """Return the coordinates of ``line`` as an array rounded to 3 decimals."""
        return np.array(line.coords).round(3)

    def union_centerline(self, centerline_geoms):
        """Merge centerline pieces into full paths through a directed graph.

        Consecutive points of every centerline become edges of a directed
        graph; additional edges link the end of each incoming lane to the
        start of the current one, and the end of the current lane to the
        start of each outgoing lane (only for tokens present in
        ``centerline_geoms``). Every simple path from a node without
        predecessors to a node without successors becomes one merged line.

        Args:
            centerline_geoms (dict): ``token -> record``; each record is
                indexed with 'centerline', 'incoming_tokens' and
                'outgoing_tokens'.

        Returns:
            tuple: ``(paths, graph)`` where ``paths`` is a list of
            ``LineString`` simplified with tolerance 0.2 and ``graph`` is the
            ``nx.DiGraph`` of points.

        Raises:
            NotImplementedError: If a centerline is neither a 'LineString'
                nor a 'MultiLineString'.
        """
        pts_G = nx.DiGraph()
        for value in centerline_geoms.values():
            centerline_geom = value['centerline']
            if centerline_geom.geom_type == 'MultiLineString':
                parts = list(centerline_geom.geoms)
            elif centerline_geom.geom_type == 'LineString':
                parts = [centerline_geom]
            else:
                raise NotImplementedError

            start_pt = self._rounded_points(parts[0])[0]
            end_pt = self._rounded_points(parts[-1])[-1]
            for part in parts:
                part_pts = self._rounded_points(part)
                for src, dst in zip(part_pts[:-1], part_pts[1:]):
                    pts_G.add_edge(tuple(src), tuple(dst))

            # Link the last point of every known predecessor to our start.
            for pred in value['incoming_tokens']:
                if pred in centerline_geoms:
                    pred_geom = centerline_geoms[pred]['centerline']
                    if pred_geom.geom_type == 'MultiLineString':
                        pred_pt = self._rounded_points(pred_geom.geoms[-1])[-1]
                    else:
                        pred_pt = self._rounded_points(pred_geom)[-1]
                    pts_G.add_edge(tuple(pred_pt), tuple(start_pt))

            # Link our end to the first point of every known successor.
            for succ in value['outgoing_tokens']:
                if succ in centerline_geoms:
                    succ_geom = centerline_geoms[succ]['centerline']
                    if succ_geom.geom_type == 'MultiLineString':
                        succ_pt = self._rounded_points(succ_geom.geoms[0])[0]
                    else:
                        succ_pt = self._rounded_points(succ_geom)[0]
                    pts_G.add_edge(tuple(end_pt), tuple(succ_pt))

        roots = (v for v, d in pts_G.in_degree() if d == 0)
        leaves = [v for v, d in pts_G.out_degree() if d == 0]
        all_paths = []
        for root in roots:
            all_paths.extend(nx.all_simple_paths(pts_G, root, leaves))

        final_centerline_paths = [
            LineString(path).simplify(0.2, preserve_topology=True)
            for path in all_paths
        ]
        return final_centerline_paths, pts_G
def create_nuscenes_infos(root_path,
                          out_path,
                          can_bus_root_path,
                          info_prefix,
                          version='v1.0-trainval',
                          max_sweeps=10):
    """Generate the map info ``.pkl`` files for one nuScenes version.

    For a test version a single ``{info_prefix}_map_infos_temporal_test.pkl``
    is written to ``out_path``; otherwise a ``..._train.pkl`` and a
    ``..._val.pkl`` file are written.

    Args:
        root_path (str): Path of the data root.
        out_path (str): Directory the info files are written to.
        can_bus_root_path (str): Data root handed to ``NuScenesCanBus``.
        info_prefix (str): Prefix of the info file to be generated.
        version (str): Version of the data. Default: 'v1.0-trainval'
        max_sweeps (int): Max number of sweeps. Default: 10
    """
    from nuscenes.nuscenes import NuScenes
    from nuscenes.can_bus.can_bus_api import NuScenesCanBus
    print(version, root_path)
    nusc = NuScenes(version=version, dataroot=root_path, verbose=True)
    nusc_can_bus = NuScenesCanBus(dataroot=can_bus_root_path)

    # One map API object and one explorer per recording location.
    MAPS = ['boston-seaport', 'singapore-hollandvillage',
            'singapore-onenorth', 'singapore-queenstown']
    nusc_maps = {}
    map_explorer = {}
    for loc in MAPS:
        nusc_maps[loc] = NuScenesMap(dataroot=root_path, map_name=loc)
        map_explorer[loc] = CNuScenesMapExplorer(nusc_maps[loc])

    from nuscenes.utils import splits
    available_vers = ['v1.0-trainval', 'v1.0-test', 'v1.0-mini']
    assert version in available_vers
    if version == 'v1.0-trainval':
        train_scenes, val_scenes = splits.train, splits.val
    elif version == 'v1.0-test':
        train_scenes, val_scenes = splits.test, []
    elif version == 'v1.0-mini':
        train_scenes, val_scenes = splits.mini_train, splits.mini_val
    else:
        raise ValueError('unknown')

    # Keep only the scenes that are actually available, then translate the
    # scene names into scene tokens (first record wins for a repeated name).
    scene_by_name = {}
    for scene in get_available_scenes(nusc):
        scene_by_name.setdefault(scene['name'], scene)
    train_scenes = {scene_by_name[name]['token']
                    for name in train_scenes if name in scene_by_name}
    val_scenes = {scene_by_name[name]['token']
                  for name in val_scenes if name in scene_by_name}

    test = 'test' in version
    if test:
        print('test scene: {}'.format(len(train_scenes)))
    else:
        print('train scene: {}, val scene: {}'.format(
            len(train_scenes), len(val_scenes)))

    train_nusc_infos, val_nusc_infos = _fill_trainval_infos(
        nusc, nusc_can_bus, nusc_maps, map_explorer, train_scenes, val_scenes,
        test, max_sweeps=max_sweeps)

    metadata = dict(version=version)
    if test:
        print('test sample: {}'.format(len(train_nusc_infos)))
        info_path = osp.join(
            out_path, '{}_map_infos_temporal_test.pkl'.format(info_prefix))
        mmcv.dump(dict(infos=train_nusc_infos, metadata=metadata), info_path)
        return

    print('train sample: {}, val sample: {}'.format(
        len(train_nusc_infos), len(val_nusc_infos)))
    info_path = osp.join(
        out_path, '{}_map_infos_temporal_train.pkl'.format(info_prefix))
    mmcv.dump(dict(infos=train_nusc_infos, metadata=metadata), info_path)
    info_val_path = osp.join(
        out_path, '{}_map_infos_temporal_val.pkl'.format(info_prefix))
    mmcv.dump(dict(infos=val_nusc_infos, metadata=metadata), info_val_path)
def nuscenes_data_prep(root_path,
                       can_bus_root_path,
                       info_prefix,
                       version,
                       dataset_name,
                       out_dir,
                       max_sweeps=10):
    """Prepare the nuScenes map info files.

    Delegates to ``create_nuscenes_infos``, which writes the '.pkl' info
    files into ``out_dir``. The 2D annotation export and ground-truth
    database steps that used to follow are disabled, so nothing else is
    generated here.

    Args:
        root_path (str): Path of dataset root.
        can_bus_root_path (str): Root path of the nuScenes CAN bus data.
        info_prefix (str): The prefix of info filenames.
        version (str): Dataset version.
        dataset_name (str): The dataset class name. Currently unused.
        out_dir (str): Output directory of the info files.
        max_sweeps (int): Number of input consecutive frames. Default: 10
    """
    create_nuscenes_infos(
        root_path=root_path,
        out_path=out_dir,
        can_bus_root_path=can_bus_root_path,
        info_prefix=info_prefix,
        version=version,
        max_sweeps=max_sweeps,
    )
# Command-line interface of the converter. The arguments are parsed at import
# time so that ``args`` is available as a module-level name.
parser = argparse.ArgumentParser(description='Data converter arg parser')
parser.add_argument(
    '--root-path',
    type=str,
    default='./data/kitti',
    help='specify the root path of dataset')
parser.add_argument(
    '--canbus',
    type=str,
    default='./data',
    help='specify the root path of nuScenes canbus')
parser.add_argument(
    '--version',
    type=str,
    default='v1.0',
    required=False,
    help='specify the dataset version, no need for kitti')
parser.add_argument(
    '--max-sweeps',
    type=int,
    default=10,
    required=False,
    help='specify sweeps of lidar per example')
# ``required`` must be the boolean False: the string 'False' is truthy, which
# made --out-dir mandatory and its default unreachable.
parser.add_argument(
    '--out-dir',
    type=str,
    default='./data/kitti',
    required=False,
    help='name of info pkl')
parser.add_argument('--extra-tag', type=str, default='nuscenes')
parser.add_argument(
    '--workers', type=int, default=4, help='number of threads to be used')
args = parser.parse_args()
if __name__ == '__main__':
    # Generate the info files for the trainval split first, then for the
    # test split; both runs share every option except the version string.
    for split_version in (f'{args.version}-trainval', f'{args.version}-test'):
        nuscenes_data_prep(
            root_path=args.root_path,
            can_bus_root_path=args.canbus,
            info_prefix=args.extra_tag,
            version=split_version,
            dataset_name='NuScenesDataset',
            out_dir=args.out_dir,
            max_sweeps=args.max_sweeps)
\ No newline at end of file
docker-hub/MapTRv2/MapTR/tools/maptrv2/nusc_vis_pred.py
0 → 100644
View file @
19472568
import
argparse
import
mmcv
import
os
import
shutil
import
torch
import
warnings
from
mmcv
import
Config
,
DictAction
from
mmcv.cnn
import
fuse_conv_bn
from
mmcv.parallel
import
MMDataParallel
,
MMDistributedDataParallel
from
mmcv.runner
import
(
get_dist_info
,
init_dist
,
load_checkpoint
,
wrap_fp16_model
)
from
mmdet3d.utils
import
collect_env
,
get_root_logger
from
mmdet3d.apis
import
single_gpu_test
from
mmdet3d.datasets
import
build_dataset
import
sys
sys
.
path
.
append
(
''
)
from
projects.mmdet3d_plugin.datasets.builder
import
build_dataloader
from
mmdet3d.models
import
build_model
from
mmdet.apis
import
set_random_seed
from
projects.mmdet3d_plugin.bevformer.apis.test
import
custom_multi_gpu_test
from
mmdet.datasets
import
replace_ImageToTensor
import
time
import
os.path
as
osp
import
numpy
as
np
from
PIL
import
Image
import
matplotlib.pyplot
as
plt
from
matplotlib
import
transforms
from
matplotlib.patches
import
Rectangle
import
cv2
# Camera names: the first three form the top row and the last three the
# bottom row of the stitched surround-view image.
CAMS = ['CAM_FRONT_LEFT', 'CAM_FRONT', 'CAM_FRONT_RIGHT',
        'CAM_BACK_LEFT', 'CAM_BACK', 'CAM_BACK_RIGHT',]
# Hand-picked sample names, chosen because they are hard cases rather than
# easy ones. Presumably each entry is '<log name>_<timestamp>'; it is matched
# against the lidar file name with '__LIDAR_TOP__' replaced by '_'.
CANDIDATE = ['n008-2018-08-01-15-16-36-0400_1533151184047036',
             'n008-2018-08-01-15-16-36-0400_1533151200646853',
             'n008-2018-08-01-15-16-36-0400_1533151274047332',
             'n008-2018-08-01-15-16-36-0400_1533151369947807',
             'n008-2018-08-01-15-16-36-0400_1533151581047647',
             'n008-2018-08-01-15-16-36-0400_1533151585447531',
             'n008-2018-08-01-15-16-36-0400_1533151741547700',
             'n008-2018-08-01-15-16-36-0400_1533151854947676',
             'n008-2018-08-22-15-53-49-0400_1534968048946931',
             'n008-2018-08-22-15-53-49-0400_1534968255947662',
             'n008-2018-08-01-15-16-36-0400_1533151616447606',
             'n015-2018-07-18-11-41-49+0800_1531885617949602',
             'n008-2018-08-28-16-43-51-0400_1535489136547616',
             'n008-2018-08-28-16-43-51-0400_1535489145446939',
             'n008-2018-08-28-16-43-51-0400_1535489152948944',
             'n008-2018-08-28-16-43-51-0400_1535489299547057',
             'n008-2018-08-28-16-43-51-0400_1535489317946828',
             'n008-2018-09-18-15-12-01-0400_1537298038950431',
             'n008-2018-09-18-15-12-01-0400_1537298047650680',
             'n008-2018-09-18-15-12-01-0400_1537298056450495',
             'n008-2018-09-18-15-12-01-0400_1537298074700410',
             'n008-2018-09-18-15-12-01-0400_1537298088148941',
             'n008-2018-09-18-15-12-01-0400_1537298101700395',
             'n015-2018-11-21-19-21-35+0800_1542799330198603',
             'n015-2018-11-21-19-21-35+0800_1542799345696426',
             'n015-2018-11-21-19-21-35+0800_1542799353697765',
             'n015-2018-11-21-19-21-35+0800_1542799525447813',
             'n015-2018-11-21-19-21-35+0800_1542799676697935',
             'n015-2018-11-21-19-21-35+0800_1542799758948001',
             ]
def perspective(cam_coords, proj_mat):
    """Project points with ``proj_mat`` and return their pixel coordinates.

    Points whose projected third coordinate (depth) is not strictly positive
    are discarded. The remaining x/y values are divided by the depth plus a
    small epsilon (1e-7).

    Args:
        cam_coords: Points stored column-wise, one point per column
            (presumably homogeneous 4xN camera coordinates - confirm with
            the caller).
        proj_mat: Projection matrix applied as ``proj_mat @ cam_coords``; the
            product must have at least three rows.

    Returns:
        Array of shape (M, 2) with the pixel coordinates of the M points that
        lie in front of the camera.
    """
    projected = proj_mat @ cam_coords
    in_front = projected[2, :] > 0
    visible = projected[:, in_front]
    depth = visible[2, :] + 1e-7
    pixels = visible[:2, :] / depth
    return pixels.transpose(1, 0)
def parse_args():
    """Parse the command-line arguments of the prediction visualizer.

    Returns:
        argparse.Namespace: With attributes ``config``, ``checkpoint``,
        ``score_thresh`` (float, default 0.4), ``show_dir`` (default None),
        ``show_cam`` (bool flag) and ``gt_format`` (list of str, default
        ``['fixed_num_pts']``).
    """
    parser = argparse.ArgumentParser(description='vis hdmaptr map gt label')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    # The help text used to read 'samples to visualize', which does not
    # describe a score threshold.
    parser.add_argument(
        '--score-thresh',
        default=0.4,
        type=float,
        help='minimum prediction score for a map element to be drawn')
    parser.add_argument(
        '--show-dir', help='directory where visualizations will be saved')
    parser.add_argument(
        '--show-cam', action='store_true', help='show camera pic')
    # The help text used to name "points" as the default, but the actual
    # default is 'fixed_num_pts'.
    parser.add_argument(
        '--gt-format',
        type=str,
        nargs='+',
        default=['fixed_num_pts',],
        help='GT vis format(s), default is "fixed_num_pts", '
        'support ["se_pts","bbox","fixed_num_pts","polyline_pts"]')
    args = parser.parse_args()
    return args
def
main
():
args
=
parse_args
()
cfg
=
Config
.
fromfile
(
args
.
config
)
# import modules from plguin/xx, registry will be updated
if
hasattr
(
cfg
,
'plugin'
):
if
cfg
.
plugin
:
import
importlib
if
hasattr
(
cfg
,
'plugin_dir'
):
plugin_dir
=
cfg
.
plugin_dir
_module_dir
=
os
.
path
.
dirname
(
plugin_dir
)
_module_dir
=
_module_dir
.
split
(
'/'
)
_module_path
=
_module_dir
[
0
]
for
m
in
_module_dir
[
1
:]:
_module_path
=
_module_path
+
'.'
+
m
print
(
_module_path
)
plg_lib
=
importlib
.
import_module
(
_module_path
)
else
:
# import dir is the dirpath for the config file
_module_dir
=
os
.
path
.
dirname
(
args
.
config
)
_module_dir
=
_module_dir
.
split
(
'/'
)
_module_path
=
_module_dir
[
0
]
for
m
in
_module_dir
[
1
:]:
_module_path
=
_module_path
+
'.'
+
m
print
(
_module_path
)
plg_lib
=
importlib
.
import_module
(
_module_path
)
# set cudnn_benchmark
if
cfg
.
get
(
'cudnn_benchmark'
,
False
):
torch
.
backends
.
cudnn
.
benchmark
=
True
cfg
.
model
.
pretrained
=
None
# in case the test dataset is concatenated
samples_per_gpu
=
1
if
isinstance
(
cfg
.
data
.
test
,
dict
):
cfg
.
data
.
test
.
test_mode
=
True
samples_per_gpu
=
cfg
.
data
.
test
.
pop
(
'samples_per_gpu'
,
1
)
if
samples_per_gpu
>
1
:
# Replace 'ImageToTensor' to 'DefaultFormatBundle'
cfg
.
data
.
test
.
pipeline
=
replace_ImageToTensor
(
cfg
.
data
.
test
.
pipeline
)
elif
isinstance
(
cfg
.
data
.
test
,
list
):
for
ds_cfg
in
cfg
.
data
.
test
:
ds_cfg
.
test_mode
=
True
samples_per_gpu
=
max
(
[
ds_cfg
.
pop
(
'samples_per_gpu'
,
1
)
for
ds_cfg
in
cfg
.
data
.
test
])
if
samples_per_gpu
>
1
:
for
ds_cfg
in
cfg
.
data
.
test
:
ds_cfg
.
pipeline
=
replace_ImageToTensor
(
ds_cfg
.
pipeline
)
if
args
.
show_dir
is
None
:
args
.
show_dir
=
osp
.
join
(
'./work_dirs'
,
osp
.
splitext
(
osp
.
basename
(
args
.
config
))[
0
],
'vis_pred'
)
# create vis_label dir
mmcv
.
mkdir_or_exist
(
osp
.
abspath
(
args
.
show_dir
))
cfg
.
dump
(
osp
.
join
(
args
.
show_dir
,
osp
.
basename
(
args
.
config
)))
logger
=
get_root_logger
()
logger
.
info
(
f
'DONE create vis_pred dir:
{
args
.
show_dir
}
'
)
dataset
=
build_dataset
(
cfg
.
data
.
test
)
dataset
.
is_vis_on_test
=
True
#TODO, this is a hack
data_loader
=
build_dataloader
(
dataset
,
samples_per_gpu
=
samples_per_gpu
,
# workers_per_gpu=cfg.data.workers_per_gpu,
workers_per_gpu
=
0
,
dist
=
False
,
shuffle
=
False
,
nonshuffler_sampler
=
cfg
.
data
.
nonshuffler_sampler
,
)
logger
.
info
(
'Done build test data set'
)
# build the model and load checkpoint
# import pdb;pdb.set_trace()
cfg
.
model
.
train_cfg
=
None
# cfg.model.pts_bbox_head.bbox_coder.max_num=15 # TODO this is a hack
model
=
build_model
(
cfg
.
model
,
test_cfg
=
cfg
.
get
(
'test_cfg'
))
fp16_cfg
=
cfg
.
get
(
'fp16'
,
None
)
if
fp16_cfg
is
not
None
:
wrap_fp16_model
(
model
)
logger
.
info
(
'loading check point'
)
checkpoint
=
load_checkpoint
(
model
,
args
.
checkpoint
,
map_location
=
'cpu'
)
if
'CLASSES'
in
checkpoint
.
get
(
'meta'
,
{}):
model
.
CLASSES
=
checkpoint
[
'meta'
][
'CLASSES'
]
else
:
model
.
CLASSES
=
dataset
.
CLASSES
# palette for visualization in segmentation tasks
if
'PALETTE'
in
checkpoint
.
get
(
'meta'
,
{}):
model
.
PALETTE
=
checkpoint
[
'meta'
][
'PALETTE'
]
elif
hasattr
(
dataset
,
'PALETTE'
):
# segmentation dataset has `PALETTE` attribute
model
.
PALETTE
=
dataset
.
PALETTE
logger
.
info
(
'DONE load check point'
)
model
=
MMDataParallel
(
model
,
device_ids
=
[
0
])
model
.
eval
()
img_norm_cfg
=
cfg
.
img_norm_cfg
# get denormalized param
mean
=
np
.
array
(
img_norm_cfg
[
'mean'
],
dtype
=
np
.
float32
)
std
=
np
.
array
(
img_norm_cfg
[
'std'
],
dtype
=
np
.
float32
)
to_bgr
=
img_norm_cfg
[
'to_rgb'
]
# get pc_range
pc_range
=
cfg
.
point_cloud_range
# get car icon
car_img
=
Image
.
open
(
'./figs/lidar_car.png'
)
# get color map: divider->r, ped->b, boundary->g
colors_plt
=
[
'orange'
,
'b'
,
'r'
,
'g'
]
logger
.
info
(
'BEGIN vis test dataset samples gt label & pred'
)
bbox_results
=
[]
mask_results
=
[]
dataset
=
data_loader
.
dataset
have_mask
=
False
# prog_bar = mmcv.ProgressBar(len(CANDIDATE))
prog_bar
=
mmcv
.
ProgressBar
(
len
(
dataset
))
# import pdb;pdb.set_trace()
for
i
,
data
in
enumerate
(
data_loader
):
if
~
(
data
[
'gt_labels_3d'
].
data
[
0
][
0
]
!=
-
1
).
any
():
# import pdb;pdb.set_trace()
logger
.
error
(
f
'
\n
empty gt for index
{
i
}
, continue'
)
# prog_bar.update()
continue
img
=
data
[
'img'
][
0
].
data
[
0
]
img_metas
=
data
[
'img_metas'
][
0
].
data
[
0
]
gt_bboxes_3d
=
data
[
'gt_bboxes_3d'
].
data
[
0
]
gt_labels_3d
=
data
[
'gt_labels_3d'
].
data
[
0
]
pts_filename
=
img_metas
[
0
][
'pts_filename'
]
pts_filename
=
osp
.
basename
(
pts_filename
)
pts_filename
=
pts_filename
.
replace
(
'__LIDAR_TOP__'
,
'_'
).
split
(
'.'
)[
0
]
# import pdb;pdb.set_trace()
# if pts_filename not in CANDIDATE:
# continue
with
torch
.
no_grad
():
result
=
model
(
return_loss
=
False
,
rescale
=
True
,
**
data
)
sample_dir
=
osp
.
join
(
args
.
show_dir
,
pts_filename
)
mmcv
.
mkdir_or_exist
(
osp
.
abspath
(
sample_dir
))
filename_list
=
img_metas
[
0
][
'filename'
]
img_path_dict
=
{}
# save cam img for sample
for
filepath
in
filename_list
:
filename
=
osp
.
basename
(
filepath
)
filename_splits
=
filename
.
split
(
'__'
)
# sample_dir = filename_splits[0]
# sample_dir = osp.join(args.show_dir, sample_dir)
# mmcv.mkdir_or_exist(osp.abspath(sample_dir))
img_name
=
filename_splits
[
1
]
+
'.jpg'
img_path
=
osp
.
join
(
sample_dir
,
img_name
)
# img_path_list.append(img_path)
shutil
.
copyfile
(
filepath
,
img_path
)
img_path_dict
[
filename_splits
[
1
]]
=
img_path
# surrounding view
row_1_list
=
[]
for
cam
in
CAMS
[:
3
]:
cam_img_name
=
cam
+
'.jpg'
cam_img
=
cv2
.
imread
(
osp
.
join
(
sample_dir
,
cam_img_name
))
row_1_list
.
append
(
cam_img
)
row_2_list
=
[]
for
cam
in
CAMS
[
3
:]:
cam_img_name
=
cam
+
'.jpg'
cam_img
=
cv2
.
imread
(
osp
.
join
(
sample_dir
,
cam_img_name
))
row_2_list
.
append
(
cam_img
)
row_1_img
=
cv2
.
hconcat
(
row_1_list
)
row_2_img
=
cv2
.
hconcat
(
row_2_list
)
cams_img
=
cv2
.
vconcat
([
row_1_img
,
row_2_img
])
cams_img_path
=
osp
.
join
(
sample_dir
,
'surroud_view.jpg'
)
cv2
.
imwrite
(
cams_img_path
,
cams_img
,[
cv2
.
IMWRITE_JPEG_QUALITY
,
70
])
for
vis_format
in
args
.
gt_format
:
if
vis_format
==
'se_pts'
:
gt_line_points
=
gt_bboxes_3d
[
0
].
start_end_points
for
gt_bbox_3d
,
gt_label_3d
in
zip
(
gt_line_points
,
gt_labels_3d
[
0
]):
pts
=
gt_bbox_3d
.
reshape
(
-
1
,
2
).
numpy
()
x
=
np
.
array
([
pt
[
0
]
for
pt
in
pts
])
y
=
np
.
array
([
pt
[
1
]
for
pt
in
pts
])
plt
.
quiver
(
x
[:
-
1
],
y
[:
-
1
],
x
[
1
:]
-
x
[:
-
1
],
y
[
1
:]
-
y
[:
-
1
],
scale_units
=
'xy'
,
angles
=
'xy'
,
scale
=
1
,
color
=
colors_plt
[
gt_label_3d
])
elif
vis_format
==
'bbox'
:
gt_lines_bbox
=
gt_bboxes_3d
[
0
].
bbox
for
gt_bbox_3d
,
gt_label_3d
in
zip
(
gt_lines_bbox
,
gt_labels_3d
[
0
]):
gt_bbox_3d
=
gt_bbox_3d
.
numpy
()
xy
=
(
gt_bbox_3d
[
0
],
gt_bbox_3d
[
1
])
width
=
gt_bbox_3d
[
2
]
-
gt_bbox_3d
[
0
]
height
=
gt_bbox_3d
[
3
]
-
gt_bbox_3d
[
1
]
# import pdb;pdb.set_trace()
plt
.
gca
().
add_patch
(
Rectangle
(
xy
,
width
,
height
,
linewidth
=
0.4
,
edgecolor
=
colors_plt
[
gt_label_3d
],
facecolor
=
'none'
))
# plt.Rectangle(xy, width, height,color=colors_plt[gt_label_3d])
# continue
elif
vis_format
==
'fixed_num_pts'
:
plt
.
figure
(
figsize
=
(
2
,
4
))
plt
.
xlim
(
pc_range
[
0
],
pc_range
[
3
])
plt
.
ylim
(
pc_range
[
1
],
pc_range
[
4
])
plt
.
axis
(
'off'
)
# gt_bboxes_3d[0].fixed_num=30 #TODO, this is a hack
gt_lines_fixed_num_pts
=
gt_bboxes_3d
[
0
].
fixed_num_sampled_points
for
gt_bbox_3d
,
gt_label_3d
in
zip
(
gt_lines_fixed_num_pts
,
gt_labels_3d
[
0
]):
# import pdb;pdb.set_trace()
pts
=
gt_bbox_3d
.
numpy
()
x
=
np
.
array
([
pt
[
0
]
for
pt
in
pts
])
y
=
np
.
array
([
pt
[
1
]
for
pt
in
pts
])
# plt.quiver(x[:-1], y[:-1], x[1:] - x[:-1], y[1:] - y[:-1], scale_units='xy', angles='xy', scale=1, color=colors_plt[gt_label_3d])
plt
.
plot
(
x
,
y
,
color
=
colors_plt
[
gt_label_3d
],
linewidth
=
1
,
alpha
=
0.8
,
zorder
=-
1
)
plt
.
scatter
(
x
,
y
,
color
=
colors_plt
[
gt_label_3d
],
s
=
2
,
alpha
=
0.8
,
zorder
=-
1
)
# plt.plot(x, y, color=colors_plt[gt_label_3d])
# plt.scatter(x, y, color=colors_plt[gt_label_3d],s=1)
plt
.
imshow
(
car_img
,
extent
=
[
-
1.2
,
1.2
,
-
1.5
,
1.5
])
gt_fixedpts_map_path
=
osp
.
join
(
sample_dir
,
'GT_fixednum_pts_MAP.png'
)
plt
.
savefig
(
gt_fixedpts_map_path
,
bbox_inches
=
'tight'
,
format
=
'png'
,
dpi
=
1200
)
plt
.
close
()
elif
vis_format
==
'polyline_pts'
:
plt
.
figure
(
figsize
=
(
2
,
4
))
plt
.
xlim
(
pc_range
[
0
],
pc_range
[
3
])
plt
.
ylim
(
pc_range
[
1
],
pc_range
[
4
])
plt
.
axis
(
'off'
)
gt_lines_instance
=
gt_bboxes_3d
[
0
].
instance_list
# import pdb;pdb.set_trace()
for
gt_line_instance
,
gt_label_3d
in
zip
(
gt_lines_instance
,
gt_labels_3d
[
0
]):
pts
=
np
.
array
(
list
(
gt_line_instance
.
coords
))
x
=
np
.
array
([
pt
[
0
]
for
pt
in
pts
])
y
=
np
.
array
([
pt
[
1
]
for
pt
in
pts
])
# plt.quiver(x[:-1], y[:-1], x[1:] - x[:-1], y[1:] - y[:-1], scale_units='xy', angles='xy', scale=1, color=colors_plt[gt_label_3d])
# plt.plot(x, y, color=colors_plt[gt_label_3d])
plt
.
plot
(
x
,
y
,
color
=
colors_plt
[
gt_label_3d
],
linewidth
=
1
,
alpha
=
0.8
,
zorder
=-
1
)
plt
.
scatter
(
x
,
y
,
color
=
colors_plt
[
gt_label_3d
],
s
=
1
,
alpha
=
0.8
,
zorder
=-
1
)
plt
.
imshow
(
car_img
,
extent
=
[
-
1.2
,
1.2
,
-
1.5
,
1.5
])
gt_polyline_map_path
=
osp
.
join
(
sample_dir
,
'GT_polyline_pts_MAP.png'
)
plt
.
savefig
(
gt_polyline_map_path
,
bbox_inches
=
'tight'
,
format
=
'png'
,
dpi
=
1200
)
plt
.
close
()
else
:
logger
.
error
(
f
'WRONG visformat for GT:
{
vis_format
}
'
)
raise
ValueError
(
f
'WRONG visformat for GT:
{
vis_format
}
'
)
# import pdb;pdb.set_trace()
plt
.
figure
(
figsize
=
(
2
,
4
))
plt
.
xlim
(
pc_range
[
0
],
pc_range
[
3
])
plt
.
ylim
(
pc_range
[
1
],
pc_range
[
4
])
plt
.
axis
(
'off'
)
# visualize pred
# import pdb;pdb.set_trace()
result_dic
=
result
[
0
][
'pts_bbox'
]
boxes_3d
=
result_dic
[
'boxes_3d'
]
# bbox: xmin, ymin, xmax, ymax
scores_3d
=
result_dic
[
'scores_3d'
]
labels_3d
=
result_dic
[
'labels_3d'
]
pts_3d
=
result_dic
[
'pts_3d'
]
keep
=
scores_3d
>
args
.
score_thresh
plt
.
figure
(
figsize
=
(
2
,
4
))
plt
.
xlim
(
pc_range
[
0
],
pc_range
[
3
])
plt
.
ylim
(
pc_range
[
1
],
pc_range
[
4
])
plt
.
axis
(
'off'
)
for
pred_score_3d
,
pred_bbox_3d
,
pred_label_3d
,
pred_pts_3d
in
zip
(
scores_3d
[
keep
],
boxes_3d
[
keep
],
labels_3d
[
keep
],
pts_3d
[
keep
]):
pred_pts_3d
=
pred_pts_3d
.
numpy
()
pts_x
=
pred_pts_3d
[:,
0
]
pts_y
=
pred_pts_3d
[:,
1
]
plt
.
plot
(
pts_x
,
pts_y
,
color
=
colors_plt
[
pred_label_3d
],
linewidth
=
1
,
alpha
=
0.8
,
zorder
=-
1
)
plt
.
scatter
(
pts_x
,
pts_y
,
color
=
colors_plt
[
pred_label_3d
],
s
=
1
,
alpha
=
0.8
,
zorder
=-
1
)
pred_bbox_3d
=
pred_bbox_3d
.
numpy
()
xy
=
(
pred_bbox_3d
[
0
],
pred_bbox_3d
[
1
])
width
=
pred_bbox_3d
[
2
]
-
pred_bbox_3d
[
0
]
height
=
pred_bbox_3d
[
3
]
-
pred_bbox_3d
[
1
]
pred_score_3d
=
float
(
pred_score_3d
)
pred_score_3d
=
round
(
pred_score_3d
,
2
)
s
=
str
(
pred_score_3d
)
plt
.
imshow
(
car_img
,
extent
=
[
-
1.2
,
1.2
,
-
1.5
,
1.5
])
map_path
=
osp
.
join
(
sample_dir
,
'PRED_MAP_plot.png'
)
plt
.
savefig
(
map_path
,
bbox_inches
=
'tight'
,
format
=
'png'
,
dpi
=
1200
)
plt
.
close
()
prog_bar
.
update
()
logger
.
info
(
'
\n
DONE vis test dataset samples gt label & pred'
)
if
__name__
==
'__main__'
:
main
()
docker-hub/MapTRv2/MapTR/tools/misc/browse_dataset.py
0 → 100644
View file @
19472568
# Copyright (c) OpenMMLab. All rights reserved.
import
argparse
import
numpy
as
np
import
warnings
from
mmcv
import
Config
,
DictAction
,
mkdir_or_exist
,
track_iter_progress
from
os
import
path
as
osp
from
mmdet3d.core.bbox
import
(
Box3DMode
,
CameraInstance3DBoxes
,
Coord3DMode
,
DepthInstance3DBoxes
,
LiDARInstance3DBoxes
)
from
mmdet3d.core.visualizer
import
(
show_multi_modality_result
,
show_result
,
show_seg_result
)
from
mmdet3d.datasets
import
build_dataset
def parse_args():
    """Parse the command-line options of the dataset browser.

    Returns:
        argparse.Namespace: Parsed options (``config``, ``skip_type``,
            ``output_dir``, ``task``, ``online`` and ``cfg_options``).
    """
    cli = argparse.ArgumentParser(description='Browse a dataset')
    # Each entry is (flags, keyword arguments) for one ``add_argument`` call.
    option_specs = [
        (('config', ), dict(help='train config file path')),
        (('--skip-type', ),
         dict(
             type=str,
             nargs='+',
             default=['Normalize'],
             help='skip some useless pipeline')),
        (('--output-dir', ),
         dict(
             default=None,
             type=str,
             help='If there is no display interface, you can save it')),
        (('--task', ),
         dict(
             type=str,
             choices=['det', 'seg', 'multi_modality-det', 'mono-det'],
             help='Determine the visualization method depending on the task.'
         )),
        (('--online', ),
         dict(
             action='store_true',
             help='Whether to perform online visualization. Note that you '
             'often need a monitor to do so.')),
        (('--cfg-options', ),
         dict(
             nargs='+',
             action=DictAction,
             help='override some settings in the used config, the key-value '
             'pair in xxx=yyy format will be merged into config file. If the '
             'value to be overwritten is a list, it should be like '
             'key="[a,b]" or key=a,b It also allows nested list/tuple '
             'values, e.g. key="[(a,b),(c,d)]" Note that the quotation marks '
             'are necessary and that no white space is allowed.')),
    ]
    for flags, spec in option_specs:
        cli.add_argument(*flags, **spec)
    return cli.parse_args()
def build_data_cfg(config_path, skip_type, cfg_options):
    """Load a config and point its training dataset at the eval pipeline.

    Args:
        config_path (str): Path of the config file to load.
        skip_type (list[str]): Pipeline step ``type`` names to leave out.
        cfg_options (dict | None): Overrides merged into the loaded config.

    Returns:
        Config: The loaded config whose ``data.train.pipeline`` holds the
            steps of ``eval_pipeline`` that are not listed in ``skip_type``.
    """
    cfg = Config.fromfile(config_path)
    if cfg_options is not None:
        cfg.merge_from_dict(cfg_options)

    # Import the modules listed in the config, if any.
    custom_imports = cfg.get('custom_imports', None)
    if custom_imports:
        from mmcv.utils import import_modules_from_strings
        import_modules_from_strings(**custom_imports)

    # Unwrap `RepeatDataset` so later code only sees the inner dataset.
    if cfg.data.train['type'] == 'RepeatDataset':
        cfg.data.train = cfg.data.train.dataset
    # For `ConcatDataset` only the first member is browsed.
    if cfg.data.train['type'] == 'ConcatDataset':
        cfg.data.train = cfg.data.train.datasets[0]

    # The eval pipeline is used for loading, minus the skipped step types.
    kept_steps = []
    for step in cfg.eval_pipeline:
        if step['type'] in skip_type:
            continue
        kept_steps.append(step)
    cfg.data.train['pipeline'] = kept_steps

    return cfg
def to_depth_mode(points, bboxes):
    """Convert points and boxes from LiDAR mode to Depth mode.

    Each argument is converted from a copy/clone of the input; an argument
    that is ``None`` is passed through unchanged.

    Args:
        points: Point data exposing ``copy()``, or ``None``.
        bboxes: Box data exposing ``clone()``, or ``None``.

    Returns:
        tuple: ``(points, bboxes)`` after conversion.
    """
    depth_points = None if points is None else Coord3DMode.convert_point(
        points.copy(), Coord3DMode.LIDAR, Coord3DMode.DEPTH)
    depth_bboxes = None if bboxes is None else Box3DMode.convert(
        bboxes.clone(), Box3DMode.LIDAR, Box3DMode.DEPTH)
    return depth_points, depth_bboxes
def show_det_data(idx, dataset, out_dir, filename, show=False):
    """Visualize the point cloud and ground-truth 3D boxes of one sample.

    Args:
        idx (int): Index of the sample in ``dataset``.
        dataset: Dataset providing ``prepare_train_data``, ``get_ann_info``
            and ``box_mode_3d``.
        out_dir (str | None): Directory handed to ``show_result``.
        filename (str): Name handed to ``show_result``.
        show (bool): Forwarded to ``show_result`` as ``show``.
    """
    sample = dataset.prepare_train_data(idx)
    cloud = sample['points']._data.numpy()
    annotations = dataset.get_ann_info(idx)
    boxes = annotations['gt_bboxes_3d'].tensor

    # Anything not already in Depth mode is converted before drawing.
    needs_conversion = dataset.box_mode_3d != Box3DMode.DEPTH
    if needs_conversion:
        cloud, boxes = to_depth_mode(cloud, boxes)

    show_result(
        cloud,
        boxes.clone(),
        None,
        out_dir,
        filename,
        show=show,
        snapshot=True)
def show_seg_data(idx, dataset, out_dir, filename, show=False):
    """Visualize the point cloud and semantic mask of one sample.

    Args:
        idx (int): Index of the sample in ``dataset``.
        dataset: Dataset providing ``prepare_train_data``, ``PALETTE`` and
            ``ignore_index``.
        out_dir (str | None): Directory handed to ``show_seg_result``.
        filename (str): Name handed to ``show_seg_result``.
        show (bool): Forwarded to ``show_seg_result`` as ``show``.
    """
    sample = dataset.prepare_train_data(idx)
    cloud = sample['points']._data.numpy()
    mask = sample['pts_semantic_mask']._data.numpy()
    palette = np.array(dataset.PALETTE)
    show_seg_result(
        cloud,
        mask.copy(),
        None,
        out_dir,
        filename,
        palette,
        dataset.ignore_index,
        show=show,
        snapshot=True)
def show_proj_bbox_img(idx,
                       dataset,
                       out_dir,
                       filename,
                       show=False,
                       is_nus_mono=False):
    """Draw the ground-truth 3D boxes of one sample onto its image.

    Args:
        idx (int): Index of the sample in ``dataset``.
        dataset: Dataset providing ``prepare_train_data`` (or
            ``prepare_train_img``) and ``get_ann_info``.
        out_dir (str | None): Directory handed to the visualizer.
        filename (str): Name handed to the visualizer.
        show (bool): Forwarded to the visualizer as ``show``.
        is_nus_mono (bool): Accepted for interface compatibility; this
            function does not read it.
    """
    try:
        sample = dataset.prepare_train_data(idx)
    except AttributeError:
        # Mono-3D datasets expose `prepare_train_img` instead.
        sample = dataset.prepare_train_img(idx)
    boxes = dataset.get_ann_info(idx)['gt_bboxes_3d']
    metas = sample['img_metas']._data
    # Move the first axis to the end: (0, 1, 2) -> (1, 2, 0).
    image = sample['img']._data.numpy().transpose(1, 2, 0)

    # Without any 3D box only the image itself is shown.
    if boxes.tensor.shape[0] == 0:
        boxes = None

    if isinstance(boxes, DepthInstance3DBoxes):
        proj_mat, mode = None, 'depth'
    elif isinstance(boxes, LiDARInstance3DBoxes):
        proj_mat, mode = metas['lidar2img'], 'lidar'
    elif isinstance(boxes, CameraInstance3DBoxes):
        proj_mat, mode = metas['cam2img'], 'camera'
    else:
        # Nothing can be projected (this also covers `boxes is None`).
        warnings.warn(
            f'unrecognized gt box type {type(boxes)}, only show image')
        show_multi_modality_result(
            image, None, None, None, out_dir, filename, show=show)
        return

    show_multi_modality_result(
        image,
        boxes,
        None,
        proj_mat,
        out_dir,
        filename,
        box_mode=mode,
        img_metas=metas,
        show=show)
def main():
    """Browse a dataset: visualize every sample according to ``--task``."""
    args = parse_args()

    out_dir = args.output_dir
    if out_dir is not None:
        mkdir_or_exist(out_dir)

    cfg = build_data_cfg(args.config, args.skip_type, args.cfg_options)
    try:
        dataset = build_dataset(
            cfg.data.train, default_args=dict(filter_empty_gt=False))
    except TypeError:
        # seg datasets do not accept the `filter_empty_gt` key
        dataset = build_dataset(cfg.data.train)
    samples = dataset.data_infos
    dataset_type = cfg.dataset_type

    # One of 'det', 'seg', 'multi_modality-det', 'mono-det' (or None).
    task = args.task
    draw_points = task in ('det', 'multi_modality-det')
    draw_projection = task in ('multi_modality-det', 'mono-det')
    draw_segmentation = task == 'seg'

    for sample_idx, info in enumerate(track_iter_progress(samples)):
        # The key holding the source file path depends on the dataset type.
        if dataset_type in ('KittiDataset', 'WaymoDataset'):
            source_path = info['point_cloud']['velodyne_path']
        elif dataset_type in ('ScanNetDataset', 'SUNRGBDDataset',
                              'ScanNetSegDataset', 'S3DISSegDataset',
                              'S3DISDataset'):
            source_path = info['pts_path']
        elif dataset_type in ('NuScenesDataset', 'LyftDataset'):
            source_path = info['lidar_path']
        elif dataset_type == 'NuScenesMonoDataset':
            source_path = info['file_name']
        else:
            raise NotImplementedError(
                f'unsupported dataset type {dataset_type}')

        stem = osp.splitext(osp.basename(source_path))[0]

        if draw_points:
            # 3D boxes on the 3D point cloud
            show_det_data(
                sample_idx, dataset, out_dir, stem, show=args.online)
        if draw_projection:
            # 3D boxes projected onto the 2D image
            show_proj_bbox_img(
                sample_idx,
                dataset,
                out_dir,
                stem,
                show=args.online,
                is_nus_mono=(dataset_type == 'NuScenesMonoDataset'))
        elif draw_segmentation:
            # segmentation mask on the 3D point cloud
            show_seg_data(
                sample_idx, dataset, out_dir, stem, show=args.online)
if
__name__
==
'__main__'
:
main
()
docker-hub/MapTRv2/MapTR/tools/misc/fuse_conv_bn.py
0 → 100644
View file @
19472568
# Copyright (c) OpenMMLab. All rights reserved.
import
argparse
import
torch
from
mmcv.runner
import
save_checkpoint
from
torch
import
nn
as
nn
from
mmdet.apis
import
init_model
def fuse_conv_bn(conv, bn):
    """Fold a batch-norm layer into the convolution that precedes it.

    The BN layer's running statistics and (optional) affine parameters are
    merged into ``conv.weight`` and ``conv.bias``, so that ``conv`` alone
    computes what ``bn(conv(x))`` computes when ``bn`` normalizes with its
    running statistics.

    Args:
        conv (nn.Module): Convolution to update in place. Its weight must
            have the output channels on the first axis.
        bn (nn.Module): Batch-norm layer applied right after ``conv``. It
            may have been built with ``affine=False`` (``weight`` and
            ``bias`` are ``None``), in which case scale 1 and shift 0 are
            used.

    Returns:
        nn.Module: ``conv``, now carrying the fused parameters.

    Raises:
        ValueError: If ``bn`` keeps no running statistics.
    """
    if bn.running_mean is None or bn.running_var is None:
        raise ValueError(
            'cannot fuse a batch norm layer that has no running statistics')

    conv_w = conv.weight
    conv_b = conv.bias if conv.bias is not None else torch.zeros_like(
        bn.running_mean)
    # BN without affine parameters behaves like scale == 1 and shift == 0.
    scale = bn.weight if bn.weight is not None else torch.ones_like(
        bn.running_mean)
    shift = bn.bias if bn.bias is not None else torch.zeros_like(
        bn.running_mean)

    factor = scale / torch.sqrt(bn.running_var + bn.eps)
    # One factor per output channel, broadcast over the remaining axes.
    broadcast_shape = [conv.out_channels] + [1] * (conv_w.dim() - 1)
    conv.weight = nn.Parameter(conv_w * factor.reshape(broadcast_shape))
    conv.bias = nn.Parameter((conv_b - bn.running_mean) * factor + shift)
    return conv
def fuse_module(m):
    """Recursively fuse ``Conv2d`` children with the BN child that follows.

    Children are visited in registration order. A ``BatchNorm2d`` or
    ``SyncBatchNorm`` child is fused into the most recent not-yet-fused
    ``Conv2d`` child of the same parent and replaced by ``nn.Identity``.
    Any other child is processed recursively.

    Args:
        m (nn.Module): Module to process in place.

    Returns:
        nn.Module: The same module ``m``.
    """
    # (name, module) of the latest conv that has not been fused yet.
    pending = None
    for child_name, child in m.named_children():
        if isinstance(child, nn.Conv2d):
            pending = (child_name, child)
        elif isinstance(child, (nn.BatchNorm2d, nn.SyncBatchNorm)):
            # A BN with no conv before it is left untouched.
            if pending is not None:
                conv_name, conv = pending
                m._modules[conv_name] = fuse_conv_bn(conv, child)
                # Keep the slot (as Identity) instead of deleting the BN so
                # the module layout stays the same.
                m._modules[child_name] = nn.Identity()
                pending = None
        else:
            fuse_module(child)
    return m
def parse_args():
    """Parse the positional arguments of the Conv/BN fusion tool.

    Returns:
        argparse.Namespace: Namespace with ``config``, ``checkpoint`` and
            ``out``.
    """
    positionals = (
        ('config', 'config file path'),
        ('checkpoint', 'checkpoint file path'),
        ('out', 'output path of the converted model'),
    )
    cli = argparse.ArgumentParser(
        description='fuse Conv and BN layers in a model')
    for arg_name, arg_help in positionals:
        cli.add_argument(arg_name, help=arg_help)
    return cli.parse_args()
def main():
    """Build a model, fuse its Conv/BN layers and save the result."""
    cli_args = parse_args()
    # The model comes from a config file plus a checkpoint file.
    net = init_model(cli_args.config, cli_args.checkpoint)
    # `fuse_module` works in place and returns the same model.
    save_checkpoint(fuse_module(net), cli_args.out)
if
__name__
==
'__main__'
:
main
()
docker-hub/MapTRv2/MapTR/tools/misc/print_config.py
0 → 100644
View file @
19472568
# Copyright (c) OpenMMLab. All rights reserved.
import
argparse
from
mmcv
import
Config
,
DictAction
def parse_args():
    """Parse the command-line options of the config printer.

    Returns:
        argparse.Namespace: Namespace with ``config`` and ``options``.
    """
    cli = argparse.ArgumentParser(description='Print the whole config')
    cli.add_argument('config', help='config file path')
    cli.add_argument(
        '--options',
        nargs='+',
        action=DictAction,
        help='arguments in dict',
    )
    return cli.parse_args()
def main():
    """Load the config, merge ``--options`` into it and print it."""
    cli_args = parse_args()
    config = Config.fromfile(cli_args.config)
    overrides = cli_args.options
    if overrides is not None:
        config.merge_from_dict(overrides)
    rendered = config.pretty_text
    print(f'Config:\n{rendered}')
if
__name__
==
'__main__'
:
main
()
docker-hub/MapTRv2/MapTR/tools/misc/visualize_results.py
0 → 100644
View file @
19472568
# Copyright (c) OpenMMLab. All rights reserved.
import
argparse
import
mmcv
from
mmcv
import
Config
from
mmdet3d.datasets
import
build_dataset
def parse_args():
    """Parse the command-line options of the result visualizer.

    Returns:
        argparse.Namespace: Namespace with ``config``, ``result`` and
            ``show_dir``.
    """
    cli = argparse.ArgumentParser(description='MMDet3D visualize the results')
    for flag, text in (
        ('config', 'test config file path'),
        ('--result', 'results file in pickle format'),
        ('--show-dir', 'directory where visualize results will be saved'),
    ):
        cli.add_argument(flag, help=text)
    return cli.parse_args()
def main():
    """Load saved results and hand them to the dataset's ``show`` method.

    Raises:
        ValueError: If ``--result`` is given but is not a ``.pkl`` or
            ``.pickle`` file.
        NotImplementedError: If the dataset has no ``show`` method.
    """
    args = parse_args()

    result_path = args.result
    if result_path is not None:
        if not result_path.endswith(('.pkl', '.pickle')):
            raise ValueError('The results file must be a pkl file.')

    cfg = Config.fromfile(args.config)
    cfg.data.test.test_mode = True

    # build the dataset
    dataset = build_dataset(cfg.data.test)
    results = mmcv.load(result_path)

    if getattr(dataset, 'show', None) is None:
        raise NotImplementedError(
            'Show is not implemented for dataset {}!'.format(
                type(dataset).__name__))

    # Use the config's eval pipeline for loading when it is non-empty;
    # otherwise let the dataset fall back to its default pipeline.
    eval_pipeline = cfg.get('eval_pipeline', {})
    show_kwargs = dict(pipeline=eval_pipeline) if eval_pipeline else {}
    dataset.show(results, args.show_dir, **show_kwargs)
if
__name__
==
'__main__'
:
main
()
docker-hub/MapTRv2/MapTR/tools/model_converters/convert_votenet_checkpoints.py
0 → 100644
View file @
19472568
# Copyright (c) OpenMMLab. All rights reserved.
import
argparse
import
tempfile
import
torch
from
mmcv
import
Config
from
mmcv.runner
import
load_state_dict
from
mmdet3d.models
import
build_detector
def parse_args():
    """Parse the command-line options of the VoteNet checkpoint upgrader.

    Returns:
        argparse.Namespace: Namespace with ``checkpoint`` and ``out``.
    """
    cli = argparse.ArgumentParser(
        description='MMDet3D upgrade model version(before v0.6.0) of VoteNet')
    for flag, text in (
        ('checkpoint', 'checkpoint file'),
        ('--out', 'path of the output checkpoint file'),
    ):
        cli.add_argument(flag, help=text)
    return cli.parse_args()
def parse_config(config_strings):
    """Parse config from strings and upgrade its VoteNet model settings.

    The text is written to a ``.py`` file inside a private temporary
    directory, loaded with ``Config.fromfile`` and the directory is removed
    again, also when loading fails.

    Args:
        config_strings (string): strings of model config.

    Returns:
        Config: model config
    """
    from os import path as osp

    # NOTE(review): relies on `Config.fromfile` reading the file eagerly, so
    # the file is not needed once the call returns.
    with tempfile.TemporaryDirectory() as tmp_dir:
        config_path = osp.join(tmp_dir, 'votenet_legacy_config.py')
        with open(config_path, 'w') as f:
            f.write(config_strings)
        config = Config.fromfile(config_path)

    backbone = config.model.backbone
    bbox_head = config.model.bbox_head

    # Update backbone config
    if 'pool_mod' in backbone:
        backbone.pop('pool_mod')

    if 'sa_cfg' not in backbone:
        backbone['sa_cfg'] = dict(
            type='PointSAModule',
            pool_mod='max',
            use_xyz=True,
            normalize_xyz=True)

    if 'type' not in bbox_head.vote_aggregation_cfg:
        bbox_head.vote_aggregation_cfg['type'] = 'PointSAModule'

    # Update bbox_head config
    if 'pred_layer_cfg' not in bbox_head:
        bbox_head['pred_layer_cfg'] = dict(
            in_channels=128, shared_conv_channels=(128, 128), bias=True)

    if 'feat_channels' in bbox_head:
        bbox_head.pop('feat_channels')

    # Old configs spell this key with a typo; move it to the correct name.
    if 'vote_moudule_cfg' in bbox_head:
        bbox_head['vote_module_cfg'] = bbox_head.pop('vote_moudule_cfg')

    # With `use_xyz` set, the first MLP channel count is reduced by 3.
    if bbox_head.vote_aggregation_cfg.use_xyz:
        bbox_head.vote_aggregation_cfg.mlp_channels[0] -= 3

    return config
def main():
    """Convert keys in checkpoints for VoteNet.

    There can be some breaking changes during the development of mmdetection3d,
    and this tool is used for upgrading checkpoints trained with old versions
    (before v0.6.0) to the latest one.

    The checkpoint's embedded config (``checkpoint['meta']['config']``) is
    upgraded with ``parse_config``, the state dict keys are deleted, renamed
    and re-sliced as described below, the result is verified by a strict load
    into a freshly built detector, and the checkpoint is saved to
    ``args.out``.

    Raises:
        NotImplementedError: If the config's ``dataset_type`` is neither
            ``'ScanNetDataset'`` nor ``'SUNRGBDDataset'``.
    """
    args = parse_args()
    # NOTE(review): no `map_location` is passed here - confirm this is fine
    # for the machines the tool is run on.
    checkpoint = torch.load(args.checkpoint)
    cfg = parse_config(checkpoint['meta']['config'])
    # Build the model and load checkpoint
    model = build_detector(
        cfg.model,
        train_cfg=cfg.get('train_cfg'),
        test_cfg=cfg.get('test_cfg'))
    orig_ckpt = checkpoint['state_dict']
    # Shallow copy: keys are edited on the copy while `orig_ckpt` stays
    # available as the source for the slicing step below.
    converted_ckpt = orig_ckpt.copy()

    if cfg['dataset_type'] == 'ScanNetDataset':
        NUM_CLASSES = 18
    elif cfg['dataset_type'] == 'SUNRGBDDataset':
        NUM_CLASSES = 10
    else:
        raise NotImplementedError

    # old key fragment -> new key fragment
    RENAME_PREFIX = {
        'bbox_head.conv_pred.0': 'bbox_head.conv_pred.shared_convs.layer0',
        'bbox_head.conv_pred.1': 'bbox_head.conv_pred.shared_convs.layer1'
    }

    DEL_KEYS = [
        'bbox_head.conv_pred.0.bn.num_batches_tracked',
        'bbox_head.conv_pred.1.bn.num_batches_tracked'
    ]

    # new key -> (old key, list of (start, end) slices along dim 0).
    # An `end` of -1 is a sentinel meaning "up to the end" (see loop below):
    # conv_cls takes rows [0:2] plus the last NUM_CLASSES rows of conv_out,
    # conv_reg takes the rows in between, [2:-NUM_CLASSES].
    EXTRACT_KEYS = {
        'bbox_head.conv_pred.conv_cls.weight':
        ('bbox_head.conv_pred.conv_out.weight', [(0, 2), (-NUM_CLASSES, -1)]),
        'bbox_head.conv_pred.conv_cls.bias':
        ('bbox_head.conv_pred.conv_out.bias', [(0, 2), (-NUM_CLASSES, -1)]),
        'bbox_head.conv_pred.conv_reg.weight':
        ('bbox_head.conv_pred.conv_out.weight', [(2, -NUM_CLASSES)]),
        'bbox_head.conv_pred.conv_reg.bias':
        ('bbox_head.conv_pred.conv_out.bias', [(2, -NUM_CLASSES)])
    }

    # Delete some useless keys
    # (`pop` without a default: a missing key raises KeyError)
    for key in DEL_KEYS:
        converted_ckpt.pop(key)

    # Rename keys with specific prefix
    # The renames are collected first and applied afterwards so the dict is
    # not modified while it is being iterated.
    RENAME_KEYS = dict()
    for old_key in converted_ckpt.keys():
        for rename_prefix in RENAME_PREFIX.keys():
            if rename_prefix in old_key:
                new_key = old_key.replace(rename_prefix,
                                          RENAME_PREFIX[rename_prefix])
                RENAME_KEYS[new_key] = old_key
    for new_key, old_key in RENAME_KEYS.items():
        converted_ckpt[new_key] = converted_ckpt.pop(old_key)

    # Extract weights and rename the keys
    for new_key, (old_key, indices) in EXTRACT_KEYS.items():
        cur_layers = orig_ckpt[old_key]
        converted_layers = []
        for (start, end) in indices:
            if end != -1:
                converted_layers.append(cur_layers[start:end])
            else:
                # -1 sentinel: take everything from `start` to the end
                converted_layers.append(cur_layers[start:])
        converted_layers = torch.cat(converted_layers, 0)
        converted_ckpt[new_key] = converted_layers
        # Each old key feeds two new keys, so it may already be gone.
        if old_key in converted_ckpt.keys():
            converted_ckpt.pop(old_key)

    # Check the converted checkpoint by loading to the model
    load_state_dict(model, converted_ckpt, strict=True)
    checkpoint['state_dict'] = converted_ckpt
    # NOTE(review): `--out` has no default in `parse_args`, so `args.out` is
    # None when the flag is omitted - confirm whether it should be required.
    torch.save(checkpoint, args.out)
if
__name__
==
'__main__'
:
main
()
docker-hub/MapTRv2/MapTR/tools/model_converters/publish_model.py
0 → 100644
View file @
19472568
# Copyright (c) OpenMMLab. All rights reserved.
import
argparse
import
subprocess
import
torch
def parse_args():
    """Parse the positional arguments of the checkpoint publishing tool.

    Returns:
        argparse.Namespace: Namespace with ``in_file`` and ``out_file``.
    """
    cli = argparse.ArgumentParser(
        description='Process a checkpoint to be published')
    for arg_name, arg_help in (
        ('in_file', 'input checkpoint filename'),
        ('out_file', 'output checkpoint filename'),
    ):
        cli.add_argument(arg_name, help=arg_help)
    return cli.parse_args()
def process_checkpoint(in_file, out_file):
    """Strip a checkpoint for publishing and tag its name with a hash.

    The checkpoint is loaded on CPU, its ``'optimizer'`` entry (if any) is
    removed, and the result is saved to ``out_file``. The saved file is then
    renamed to ``<out_file without '.pth'>-<first 8 hex digits of its
    SHA-256>.pth``.

    Args:
        in_file (str): Path of the checkpoint to read.
        out_file (str): Path the stripped checkpoint is first written to.

    Returns:
        str: Path of the final, hash-tagged checkpoint file.
    """
    import hashlib
    import os

    checkpoint = torch.load(in_file, map_location='cpu')
    # remove optimizer for smaller file size
    if 'optimizer' in checkpoint:
        del checkpoint['optimizer']
    # if it is necessary to remove some sensitive data in checkpoint['meta'],
    # add the code here.
    torch.save(checkpoint, out_file)

    # Hash the written file in chunks so large checkpoints are not read
    # into memory at once.
    digest = hashlib.sha256()
    with open(out_file, 'rb') as f:
        for chunk in iter(lambda: f.read(1 << 20), b''):
            digest.update(chunk)
    sha = digest.hexdigest()

    # Remove the '.pth' suffix only. `str.rstrip('.pth')` would strip any
    # trailing run of the characters '.', 'p', 't', 'h' and mangle names
    # such as 'depth.pth'.
    suffix = '.pth'
    stem = out_file[:-len(suffix)] if out_file.endswith(suffix) else out_file
    final_file = stem + '-{}.pth'.format(sha[:8])
    # Rename synchronously, so the file is in place when this returns.
    os.replace(out_file, final_file)
    return final_file
def main():
    """Run ``process_checkpoint`` on the files named on the command line."""
    cli_args = parse_args()
    source_path, target_path = cli_args.in_file, cli_args.out_file
    process_checkpoint(source_path, target_path)
if
__name__
==
'__main__'
:
main
()
docker-hub/MapTRv2/MapTR/tools/model_converters/regnet2mmdet.py
0 → 100644
View file @
19472568
# Copyright (c) OpenMMLab. All rights reserved.
import
argparse
import
torch
from
collections
import
OrderedDict
def convert_stem(model_key, model_weight, state_dict, converted_names):
    """Store a stem weight under its converted key.

    ``'stem.conv'`` becomes ``'conv1'`` and ``'stem.bn'`` becomes ``'bn1'``.

    Args:
        model_key (str): Key of the weight in the source checkpoint.
        model_weight: Value stored under the converted key.
        state_dict (dict): Destination mapping, updated in place.
        converted_names (set): Set of handled source keys, updated in place.
    """
    renamed = model_key
    for old_prefix, new_prefix in (('stem.conv', 'conv1'), ('stem.bn',
                                                             'bn1')):
        renamed = renamed.replace(old_prefix, new_prefix)
    state_dict[renamed] = model_weight
    converted_names.add(model_key)
    print(f'Convert {model_key} to {renamed}')
def convert_head(model_key, model_weight, state_dict, converted_names):
    """Store a head weight under its converted key (``'head.fc'`` -> ``'fc'``).

    Args:
        model_key (str): Key of the weight in the source checkpoint.
        model_weight: Value stored under the converted key.
        state_dict (dict): Destination mapping, updated in place.
        converted_names (set): Set of handled source keys, updated in place.
    """
    renamed = 'fc'.join(model_key.split('head.fc'))
    state_dict[renamed] = model_weight
    converted_names.add(model_key)
    print(f'Convert {model_key} to {renamed}')
def convert_reslayer(model_key, model_weight, state_dict, converted_names):
    """Store a residual-stage weight under its converted key.

    Source keys look like ``s<stage>.b<block>.<module>...<param>``. The
    stage becomes ``layer<stage>`` and the block index becomes zero-based.
    For the first block, ``proj`` maps to ``downsample.0`` and ``bn`` maps
    to ``downsample.1``. Under ``f``, ``a``/``b``/``c`` map to
    ``conv1``/``conv2``/``conv3`` and ``a_bn``/``b_bn``/``c_bn`` map to
    ``bn1``/``bn2``/``bn3``.

    Args:
        model_key (str): Key of the weight in the source checkpoint.
        model_weight: Value stored under the converted key.
        state_dict (dict): Destination mapping, updated in place.
        converted_names (set): Set of handled source keys, updated in place.

    Raises:
        ValueError: If ``model_key`` does not match any supported pattern.
            Nothing is stored in that case.
    """
    # Sub-module names under `f` and their converted counterparts.
    f_modules = {
        'a_bn': 'bn1',
        'b_bn': 'bn2',
        'c_bn': 'bn3',
        'a': 'conv1',
        'b': 'conv2',
        'c': 'conv3',
    }
    unsupported = ValueError(f'Unsupported conversion of key {model_key}')

    split_keys = model_key.split('.')
    if len(split_keys) < 3:
        raise unsupported
    layer, block, module = split_keys[:3]
    block_id = int(block[1:])
    layer_name = f'layer{int(layer[1:])}'
    block_name = f'{block_id - 1}'
    param_name = split_keys[-1]

    if block_id == 1 and module == 'bn':
        new_key = f'{layer_name}.{block_name}.downsample.1.{param_name}'
    elif block_id == 1 and module == 'proj':
        new_key = f'{layer_name}.{block_name}.downsample.0.{param_name}'
    elif module == 'f':
        # Unknown or missing sub-modules are reported as unsupported keys
        # instead of failing on an unbound local or an index error.
        if len(split_keys) < 4 or split_keys[3] not in f_modules:
            raise unsupported
        module_name = f_modules[split_keys[3]]
        new_key = f'{layer_name}.{block_name}.{module_name}.{param_name}'
    else:
        raise unsupported

    print(f'Convert {model_key} to {new_key}')
    state_dict[new_key] = model_weight
    converted_names.add(model_key)
def convert(src, dst):
    """Convert keys in pycls pretrained RegNet models to mmdet style.

    Args:
        src (str): Path of the source checkpoint; its weights are read from
            the ``'model_state'`` entry.
        dst (str): Path the converted checkpoint is saved to, with the
            weights stored under ``'state_dict'``.
    """
    source_ckpt = torch.load(src)
    source_weights = source_ckpt['model_state']

    converted_weights = OrderedDict()
    handled = set()
    for name, tensor in source_weights.items():
        # Pick the converter from the key; keys matching none are skipped
        # here and reported below.
        if 'stem' in name:
            converter = convert_stem
        elif 'head' in name:
            converter = convert_head
        elif name.startswith('s'):
            converter = convert_reslayer
        else:
            continue
        converter(name, tensor, converted_weights, handled)

    # Report every source key that was not converted.
    for name in source_weights:
        if name in handled:
            continue
        print(f'not converted: {name}')

    torch.save(dict(state_dict=converted_weights), dst)
def main():
    """Parse ``src`` and ``dst`` from the command line and run ``convert``."""
    cli = argparse.ArgumentParser(description='Convert model keys')
    for arg_name, arg_help in (
        ('src', 'src detectron model path'),
        ('dst', 'save path'),
    ):
        cli.add_argument(arg_name, help=arg_help)
    cli_args = cli.parse_args()
    convert(cli_args.src, cli_args.dst)
if
__name__
==
'__main__'
:
main
()
docker-hub/MapTRv2/MapTR/tools/test.py
0 → 100644
View file @
19472568
# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# Modified by Zhiqi Li
# ---------------------------------------------
import
argparse
import
mmcv
import
os
import
torch
import
warnings
from
mmcv
import
Config
,
DictAction
from
mmcv.cnn
import
fuse_conv_bn
from
mmcv.parallel
import
MMDataParallel
,
MMDistributedDataParallel
from
mmcv.runner
import
(
get_dist_info
,
init_dist
,
load_checkpoint
,
wrap_fp16_model
)
from
mmdet3d.apis
import
single_gpu_test
from
mmdet3d.datasets
import
build_dataset
from
projects.mmdet3d_plugin.datasets.builder
import
build_dataloader
from
mmdet3d.models
import
build_model
from
mmdet.apis
import
set_random_seed
from
projects.mmdet3d_plugin.bevformer.apis.test
import
custom_multi_gpu_test
from
mmdet.datasets
import
replace_ImageToTensor
import
time
import
os.path
as
osp
def parse_args():
    """Parse the command-line options of the test script.

    Besides parsing, this function:

    * sets ``os.environ['LOCAL_RANK']`` from ``--local_rank`` when the
      variable is not already set;
    * maps the deprecated ``--options`` onto ``eval_options`` (with a
      warning).

    Returns:
        argparse.Namespace: The parsed options.

    Raises:
        ValueError: If both ``--options`` and ``--eval-options`` are given.
    """
    parser = argparse.ArgumentParser(
        description='MMDet test (and eval) a model')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument('--out', help='output result file in pickle format')
    # The two help texts below used to lack the space between their
    # concatenated parts ("increasethe", "It isuseful").
    parser.add_argument(
        '--fuse-conv-bn',
        action='store_true',
        help='Whether to fuse conv and bn, this will slightly increase '
        'the inference speed')
    parser.add_argument(
        '--format-only',
        action='store_true',
        help='Format the output results without perform evaluation. It is '
        'useful when you want to format the result to a specific format and '
        'submit it to the test server')
    parser.add_argument(
        '--eval',
        type=str,
        nargs='+',
        help='evaluation metrics, which depends on the dataset, e.g., "bbox",'
        ' "segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC')
    parser.add_argument('--show', action='store_true', help='show results')
    parser.add_argument(
        '--show-dir', help='directory where results will be saved')
    parser.add_argument(
        '--gpu-collect',
        action='store_true',
        help='whether to use gpu to collect results.')
    parser.add_argument(
        '--tmpdir',
        help='tmp directory used for collecting results from multiple '
        'workers, available when gpu-collect is not specified')
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--options',
        nargs='+',
        action=DictAction,
        help='custom options for evaluation, the key-value pair in xxx=yyy '
        'format will be kwargs for dataset.evaluate() function (deprecate), '
        'change to --eval-options instead.')
    parser.add_argument(
        '--eval-options',
        nargs='+',
        action=DictAction,
        help='custom options for evaluation, the key-value pair in xxx=yyy '
        'format will be kwargs for dataset.evaluate() function')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()

    # Only fill in LOCAL_RANK when the environment does not provide it.
    os.environ.setdefault('LOCAL_RANK', str(args.local_rank))

    if args.options and args.eval_options:
        raise ValueError(
            '--options and --eval-options cannot be both specified, '
            '--options is deprecated in favor of --eval-options')
    if args.options:
        warnings.warn('--options is deprecated in favor of --eval-options')
        args.eval_options = args.options
    return args
def _import_plugin_package(cfg, config_path):
    """Import the plugin package so that its registries get populated.

    Nothing happens unless ``cfg.plugin`` exists and is truthy. The package
    is the directory part of ``cfg.plugin_dir`` when that is set, otherwise
    the directory of the config file, with ``'/'`` turned into ``'.'``.
    """
    if not (hasattr(cfg, 'plugin') and cfg.plugin):
        return
    import importlib
    if hasattr(cfg, 'plugin_dir'):
        package_dir = os.path.dirname(cfg.plugin_dir)
    else:
        # import dir is the dirpath for the config file
        package_dir = os.path.dirname(config_path)
    module_path = '.'.join(package_dir.split('/'))
    print(module_path)
    importlib.import_module(module_path)


def main():
    """Run distributed testing of a model and format/evaluate the results.

    Raises:
        ValueError: On inconsistent command-line options.
        NotImplementedError: For ``--launcher none`` (non-distributed
            testing) and for ``--out`` (dumping raw results), both of which
            are disabled in this script. These are reported before any
            heavy work starts.
    """
    args = parse_args()

    assert args.out or args.eval or args.format_only or args.show \
        or args.show_dir, \
        ('Please specify at least one operation (save/eval/format/show the '
         'results / save the results) with the argument "--out", "--eval"'
         ', "--format-only", "--show" or "--show-dir"')

    if args.eval and args.format_only:
        raise ValueError('--eval and --format_only cannot be both specified')

    if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
        raise ValueError('The output file must be a pkl file.')

    # These two modes used to be guarded by bare `assert False` statements
    # (one of them only after the whole inference had run). Asserts vanish
    # under `python -O`, so fail explicitly and up front instead.
    if args.out:
        raise NotImplementedError(
            'writing raw results with --out is disabled in this script; '
            'use --eval or --format-only instead')
    if args.launcher == 'none':
        raise NotImplementedError(
            'non-distributed testing is disabled in this script; '
            'start it with --launcher pytorch/slurm/mpi')

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # import modules from string list.
    if cfg.get('custom_imports', None):
        from mmcv.utils import import_modules_from_strings
        import_modules_from_strings(**cfg['custom_imports'])

    # import modules from plugin/xx, registry will be updated
    _import_plugin_package(cfg, args.config)

    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    cfg.model.pretrained = None
    # in case the test dataset is concatenated
    samples_per_gpu = 1
    if isinstance(cfg.data.test, dict):
        cfg.data.test.test_mode = True
        samples_per_gpu = cfg.data.test.pop('samples_per_gpu', 1)
        if samples_per_gpu > 1:
            # Replace 'ImageToTensor' to 'DefaultFormatBundle'
            cfg.data.test.pipeline = replace_ImageToTensor(
                cfg.data.test.pipeline)
    elif isinstance(cfg.data.test, list):
        for ds_cfg in cfg.data.test:
            ds_cfg.test_mode = True
        samples_per_gpu = max(
            [ds_cfg.pop('samples_per_gpu', 1) for ds_cfg in cfg.data.test])
        if samples_per_gpu > 1:
            for ds_cfg in cfg.data.test:
                ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline)

    # init distributed env first, since logger depends on the dist info.
    # (the launcher is known not to be 'none' at this point)
    distributed = True
    init_dist(args.launcher, **cfg.dist_params)

    # set random seeds
    if args.seed is not None:
        set_random_seed(args.seed, deterministic=args.deterministic)

    # build the dataloader
    dataset = build_dataset(cfg.data.test)
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=samples_per_gpu,
        workers_per_gpu=cfg.data.workers_per_gpu,
        dist=distributed,
        shuffle=False,
        nonshuffler_sampler=cfg.data.nonshuffler_sampler,
    )

    # build the model and load checkpoint
    cfg.model.train_cfg = None
    model = build_model(cfg.model, test_cfg=cfg.get('test_cfg'))
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        wrap_fp16_model(model)
    checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
    if args.fuse_conv_bn:
        model = fuse_conv_bn(model)
    # old versions did not save class info in checkpoints, this workaround
    # is for backward compatibility
    checkpoint_meta = checkpoint.get('meta', {})
    if 'CLASSES' in checkpoint_meta:
        model.CLASSES = checkpoint_meta['CLASSES']
    else:
        model.CLASSES = dataset.CLASSES
    # palette for visualization in segmentation tasks
    if 'PALETTE' in checkpoint_meta:
        model.PALETTE = checkpoint_meta['PALETTE']
    elif hasattr(dataset, 'PALETTE'):
        # segmentation dataset has `PALETTE` attribute
        model.PALETTE = dataset.PALETTE

    model = MMDistributedDataParallel(
        model.cuda(),
        device_ids=[torch.cuda.current_device()],
        broadcast_buffers=False)
    outputs = custom_multi_gpu_test(model, data_loader, args.tmpdir,
                                    args.gpu_collect)

    rank, _ = get_dist_info()
    if rank == 0:
        kwargs = {} if args.eval_options is None else args.eval_options
        # Results go under test/<config file stem>/<timestamp>.
        kwargs['jsonfile_prefix'] = osp.join(
            'test',
            args.config.split('/')[-1].split('.')[-2],
            time.ctime().replace(' ', '_').replace(':', '_'))
        if args.format_only:
            dataset.format_results(outputs, **kwargs)

        if args.eval:
            eval_kwargs = cfg.get('evaluation', {}).copy()
            # hard-code way to remove EvalHook args
            for key in [
                    'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best',
                    'rule'
            ]:
                eval_kwargs.pop(key, None)
            eval_kwargs.update(dict(metric=args.eval, **kwargs))

            print(dataset.evaluate(outputs, **eval_kwargs))
# Script entry point: call main() only when this file is executed directly.
if
__name__
==
'__main__'
:
main
()
docker-hub/MapTRv2/MapTR/tools/train.py
0 → 100644
View file @
19472568
# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# Modified by Zhiqi Li
# ---------------------------------------------
from
__future__
import
division
import
argparse
import
copy
import
mmcv
import
os
import
time
import
torch
import
warnings
from
mmcv
import
Config
,
DictAction
from
mmcv.runner
import
get_dist_info
,
init_dist
from
os
import
path
as
osp
from
mmdet
import
__version__
as
mmdet_version
from
mmdet3d
import
__version__
as
mmdet3d_version
#from mmdet3d.apis import train_model
from
mmdet3d.datasets
import
build_dataset
from
mmdet3d.models
import
build_model
from
mmdet3d.utils
import
collect_env
,
get_root_logger
from
mmdet.apis
import
set_random_seed
from
mmseg
import
__version__
as
mmseg_version
from
mmcv.utils
import
TORCH_VERSION
,
digit_version
def parse_args():
    """Parse the command-line arguments of the training script.

    Besides building the parser, this function exports ``LOCAL_RANK`` to the
    process environment when it is not already set, and folds the deprecated
    ``--options`` flag into ``--cfg-options``.

    Returns:
        argparse.Namespace: The parsed arguments. ``cfg_options`` holds the
        config overrides given through ``--cfg-options`` or ``--options``.

    Raises:
        ValueError: If both ``--options`` and ``--cfg-options`` are given.
    """
    parser = argparse.ArgumentParser(description='Train a detector')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument(
        '--resume-from', help='the checkpoint file to resume from')
    parser.add_argument(
        '--no-validate',
        action='store_true',
        help='whether not to evaluate the checkpoint during training')

    # --gpus and --gpu-ids are two ways to say the same thing, so only one of
    # them may be given.
    group_gpus = parser.add_mutually_exclusive_group()
    group_gpus.add_argument(
        '--gpus',
        type=int,
        help='number of gpus to use '
        '(only applicable to non-distributed training)')
    group_gpus.add_argument(
        '--gpu-ids',
        type=int,
        nargs='+',
        help='ids of gpus to use '
        '(only applicable to non-distributed training)')

    parser.add_argument('--seed', type=int, default=0, help='random seed')
    parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    parser.add_argument(
        '--options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file (deprecate), '
        'change to --cfg-options instead.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    # Accept both spellings: older torch.distributed.launch versions pass
    # --local_rank while PyTorch >= 2.0 passes --local-rank. The destination
    # stays ``args.local_rank`` because it is derived from the first option.
    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
    parser.add_argument(
        '--autoscale-lr',
        action='store_true',
        help='automatically scale lr with the number of gpus')
    args = parser.parse_args()

    # Keep a LOCAL_RANK that the launcher already exported; otherwise fall
    # back to the command-line value.
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    if args.options and args.cfg_options:
        raise ValueError(
            '--options and --cfg-options cannot be both specified, '
            '--options is deprecated in favor of --cfg-options')
    if args.options:
        warnings.warn('--options is deprecated in favor of --cfg-options')
        args.cfg_options = args.options

    return args
def _import_plugin_package(dir_path):
    """Import the package that lives at the '/'-separated ``dir_path``.

    The directory path is converted to a dotted module path (``a/b`` becomes
    ``a.b``), printed to stdout and imported; per the original comment the
    import is what updates the registries.
    """
    import importlib

    module_path = '.'.join(dir_path.split('/'))
    print(module_path)
    return importlib.import_module(module_path)


def main():
    """Run training for the config file given on the command line.

    Steps, in order: parse arguments, load and patch the config, import the
    plugin package, resolve ``work_dir``/``resume_from``/``gpu_ids``,
    initialise the distributed environment, set up logging and the ``meta``
    dict, seed the RNGs, build the model and dataset(s), and finally hand
    everything to ``custom_train_model``.
    """
    args = parse_args()

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # import modules from string list.
    if cfg.get('custom_imports', None):
        from mmcv.utils import import_modules_from_strings
        import_modules_from_strings(**cfg['custom_imports'])

    # import modules from plugin/xx, registry will be updated
    if hasattr(cfg, 'plugin') and cfg.plugin:
        if hasattr(cfg, 'plugin_dir'):
            plugin_root = cfg.plugin_dir
        else:
            # import dir is the dirpath for the config file
            plugin_root = args.config
        _import_plugin_package(os.path.dirname(plugin_root))

    # The training entry point is called unconditionally at the end of this
    # function, so it is imported unconditionally as well (after the plugin
    # package, preserving the original import order).
    from projects.mmdet3d_plugin.bevformer.apis.train import custom_train_model

    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])

    # A --resume-from path that does not point to an existing file is
    # ignored, so training then uses whatever the config specifies.
    if args.resume_from is not None and osp.isfile(args.resume_from):
        cfg.resume_from = args.resume_from

    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids
    else:
        cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)

    if (digit_version(TORCH_VERSION) == digit_version('1.8.1')
            and cfg.optimizer['type'] == 'AdamW'):
        cfg.optimizer['type'] = 'AdamW2'  # fix bug in Adamw

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)
        # re-set gpu_ids with distributed training mode
        _, world_size = get_dist_info()
        cfg.gpu_ids = range(world_size)

    if args.autoscale_lr:
        # apply the linear scaling rule (https://arxiv.org/abs/1706.02677)
        # This must run after gpu_ids has been reset to the world size above;
        # doing it earlier scaled distributed runs by 1/8 regardless of the
        # real number of GPUs.
        cfg.optimizer['lr'] = cfg.optimizer['lr'] * len(cfg.gpu_ids) / 8

    # create work_dir
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # dump config
    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
    # specify logger name, if we still use 'mmdet', the output info will be
    # filtered and won't be saved in the log_file
    # TODO: ugly workaround to judge whether we are training det or seg model
    if cfg.model.type in ['EncoderDecoder3D']:
        logger_name = 'mmseg'
    else:
        logger_name = 'mmdet'
    logger = get_root_logger(
        log_file=log_file, log_level=cfg.log_level, name=logger_name)

    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    meta = dict()
    # log env info
    env_info_dict = collect_env()
    env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
                dash_line)
    meta['env_info'] = env_info
    meta['config'] = cfg.pretty_text

    # log some basic info
    logger.info(f'Distributed training: {distributed}')
    logger.info(f'Config:\n{cfg.pretty_text}')

    # set random seeds
    if args.seed is not None:
        logger.info(f'Set random seed to {args.seed}, '
                    f'deterministic: {args.deterministic}')
        set_random_seed(args.seed, deterministic=args.deterministic)
    cfg.seed = args.seed
    meta['seed'] = args.seed
    meta['exp_name'] = osp.basename(args.config)

    model = build_model(
        cfg.model,
        train_cfg=cfg.get('train_cfg'),
        test_cfg=cfg.get('test_cfg'))
    # Switch the image backbone to the channels_last memory format.
    # NOTE(review): this requires every model trained through this script to
    # expose an `img_backbone` attribute - confirm for non-camera configs.
    model.img_backbone = model.img_backbone.to(
        memory_format=torch.channels_last)
    model.init_weights()

    logger.info(f'Model:\n{model}')
    datasets = [build_dataset(cfg.data.train)]
    if len(cfg.workflow) == 2:
        val_dataset = copy.deepcopy(cfg.data.val)
        # in case we use a dataset wrapper
        if 'dataset' in cfg.data.train:
            val_dataset.pipeline = cfg.data.train.dataset.pipeline
        else:
            val_dataset.pipeline = cfg.data.train.pipeline
        # set test_mode=False here in deep copied config
        # which do not affect AP/AR calculation later
        # refer to https://mmdetection3d.readthedocs.io/en/latest/tutorials/customize_runtime.html#customize-workflow  # noqa
        val_dataset.test_mode = False
        datasets.append(build_dataset(val_dataset))
    if cfg.checkpoint_config is not None:
        # save mmdet version, config file content and class names in
        # checkpoints as meta data
        cfg.checkpoint_config.meta = dict(
            mmdet_version=mmdet_version,
            mmseg_version=mmseg_version,
            mmdet3d_version=mmdet3d_version,
            config=cfg.pretty_text,
            CLASSES=datasets[0].CLASSES,
            # PALETTE is only present on segmentation datasets
            PALETTE=getattr(datasets[0], 'PALETTE', None))
    # add an attribute for visualization convenience
    model.CLASSES = datasets[0].CLASSES

    # Speed-oriented cuDNN settings. They are skipped when --deterministic
    # was requested, because forcing them here would silently undo that flag.
    if not args.deterministic:
        # let cuDNN search for the fastest convolution algorithm
        torch.backends.cudnn.benchmark = True
        # allow non-deterministic algorithms for extra speed
        torch.backends.cudnn.deterministic = False

    custom_train_model(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=(not args.no_validate),
        timestamp=timestamp,
        meta=meta)
# Script entry point: when run directly, select the 'fork' start method for
# torch.multiprocessing and then start training via main().
if
__name__
==
'__main__'
:
torch
.
multiprocessing
.
set_start_method
(
'fork'
)
main
()
Prev
1
…
8
9
10
11
12
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment