Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
RODNet
Commits
a8863510
Commit
a8863510
authored
Nov 10, 2020
by
Yizhou Wang
Browse files
v1.0: first commit
parent
16d8dda7
Changes
46
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
881 additions
and
0 deletions
+881
-0
rodnet/utils/visualization/fig_configs.py
rodnet/utils/visualization/fig_configs.py
+12
-0
rodnet/utils/visualization/postprocessing.py
rodnet/utils/visualization/postprocessing.py
+7
-0
rodnet/utils/visualization/ramap.py
rodnet/utils/visualization/ramap.py
+150
-0
tools/prepare_dataset/prepare_data.py
tools/prepare_dataset/prepare_data.py
+168
-0
tools/test.py
tools/test.py
+255
-0
tools/train.py
tools/train.py
+289
-0
No files found.
rodnet/utils/visualization/fig_configs.py
0 → 100644
View file @
a8863510
# -*- coding: utf-8 -*-
import
matplotlib.pyplot
as
plt
from
matplotlib.font_manager
import
FontProperties
fig
=
plt
.
figure
(
figsize
=
(
8
,
8
))
fp
=
FontProperties
(
fname
=
r
"assets/fontawesome-free-5.12.0-desktop/otfs/solid-900.otf"
)
symbols
=
{
'pedestrian'
:
"
\uf554
"
,
'cyclist'
:
"
\uf84a
"
,
'car'
:
"
\uf1b9
"
,
}
rodnet/utils/visualization/postprocessing.py
0 → 100644
View file @
a8863510
import
matplotlib.pyplot
as
plt
def
visualize_ols_hist
(
olss_flatten
):
_
=
plt
.
hist
(
olss_flatten
,
bins
=
'auto'
)
# arguments are passed to np.histogram
plt
.
title
(
"OLS Distribution"
)
plt
.
show
()
rodnet/utils/visualization/ramap.py
0 → 100644
View file @
a8863510
import
numpy
as
np
import
matplotlib.pyplot
as
plt
from
rodnet.core.radar_processing.chirp_ops
import
chirp_amp
def
visualize_radar_chirp
(
chirp
,
radar_data_type
):
"""
Visualize radar data of one chirp
:param chirp: (w x h x 2) or (2 x w x h)
:param radar_data_type: current available types include 'RI', 'RISEP', 'AP', 'APSEP'
:return:
"""
chirp_abs
=
chirp_amp
(
chirp
,
radar_data_type
)
plt
.
imshow
(
chirp_abs
)
plt
.
show
()
def
visualize_radar_chirps
(
chirps
,
radar_data_type
):
"""
Visualize radar data of multiple chirps
:param chirps: (N x w x h x 2) or (N x 2 x w x h)
:param radar_data_type: current available types include 'RI', 'RISEP', 'AP', 'APSEP'
:return:
"""
num_chirps
,
c0
,
c1
,
c2
=
chirps
.
shape
if
c2
==
2
:
chirps_abs
=
np
.
zeros
((
num_chirps
,
c0
,
c1
))
elif
c0
==
2
:
chirps_abs
=
np
.
zeros
((
num_chirps
,
c1
,
c2
))
else
:
raise
ValueError
for
chirp_id
in
range
(
num_chirps
):
chirps_abs
[
chirp_id
,
:,
:]
=
chirp_amp
(
chirps
[
chirp_id
,
:,
:,
:],
radar_data_type
)
chirp_abs_avg
=
np
.
mean
(
chirps_abs
,
axis
=
0
)
plt
.
imshow
(
chirp_abs_avg
)
plt
.
show
()
def
visualize_fuse_crdets
(
chirp
,
obj_dicts
,
figname
=
None
,
viz
=
False
):
chirp_abs
=
chirp_amp
(
chirp
)
chirp_shape
=
chirp_abs
.
shape
plt
.
figure
()
plt
.
imshow
(
chirp_abs
,
vmin
=
0
,
vmax
=
1
,
origin
=
'lower'
)
for
obj_id
,
obj_dict
in
enumerate
(
obj_dicts
):
plt
.
scatter
(
obj_dict
[
'angle_id'
],
obj_dict
[
'range_id'
],
s
=
10
,
c
=
'white'
)
try
:
text
=
str
(
obj_dict
[
'object_id'
])
+
' '
+
obj_dict
[
'class'
]
except
:
text
=
str
(
obj_dict
[
'object_id'
])
plt
.
text
(
obj_dict
[
'angle_id'
]
+
5
,
obj_dict
[
'range_id'
],
text
,
color
=
'white'
,
fontsize
=
10
)
plt
.
xlim
(
0
,
chirp_shape
[
1
])
plt
.
ylim
(
0
,
chirp_shape
[
0
])
if
viz
:
plt
.
show
()
else
:
plt
.
savefig
(
figname
)
plt
.
close
()
def
visualize_fuse_crdets_compare
(
img_path
,
chirp
,
c_dicts
,
r_dicts
,
cr_dicts
,
figname
=
None
,
viz
=
False
):
chirp_abs
=
chirp_amp
(
chirp
)
chirp_shape
=
chirp_abs
.
shape
fig_local
=
plt
.
figure
()
fig_local
.
set_size_inches
(
16
,
4
)
fig_local
.
add_subplot
(
1
,
4
,
1
)
im
=
plt
.
imread
(
img_path
)
plt
.
imshow
(
im
)
fig_local
.
add_subplot
(
1
,
4
,
2
)
plt
.
imshow
(
chirp_abs
,
vmin
=
0
,
vmax
=
1
,
origin
=
'lower'
)
for
obj_id
,
obj_dict
in
enumerate
(
c_dicts
):
plt
.
scatter
(
obj_dict
[
'angle_id'
],
obj_dict
[
'range_id'
],
s
=
10
,
c
=
'white'
)
try
:
obj_dict
[
'object_id'
]
except
:
obj_dict
[
'object_id'
]
=
''
try
:
text
=
str
(
obj_dict
[
'object_id'
])
+
' '
+
obj_dict
[
'class'
]
except
:
text
=
str
(
obj_dict
[
'object_id'
])
plt
.
text
(
obj_dict
[
'angle_id'
]
+
5
,
obj_dict
[
'range_id'
],
text
,
color
=
'white'
,
fontsize
=
10
)
plt
.
xlim
(
0
,
chirp_shape
[
1
])
plt
.
ylim
(
0
,
chirp_shape
[
0
])
fig_local
.
add_subplot
(
1
,
4
,
3
)
plt
.
imshow
(
chirp_abs
,
vmin
=
0
,
vmax
=
1
,
origin
=
'lower'
)
for
obj_id
,
obj_dict
in
enumerate
(
r_dicts
):
plt
.
scatter
(
obj_dict
[
'angle_id'
],
obj_dict
[
'range_id'
],
s
=
10
,
c
=
'white'
)
try
:
obj_dict
[
'object_id'
]
except
:
obj_dict
[
'object_id'
]
=
''
try
:
text
=
str
(
obj_dict
[
'object_id'
])
+
' '
+
obj_dict
[
'class'
]
except
:
text
=
str
(
obj_dict
[
'object_id'
])
plt
.
text
(
obj_dict
[
'angle_id'
]
+
5
,
obj_dict
[
'range_id'
],
text
,
color
=
'white'
,
fontsize
=
10
)
plt
.
xlim
(
0
,
chirp_shape
[
1
])
plt
.
ylim
(
0
,
chirp_shape
[
0
])
fig_local
.
add_subplot
(
1
,
4
,
4
)
plt
.
imshow
(
chirp_abs
,
vmin
=
0
,
vmax
=
1
,
origin
=
'lower'
)
for
obj_id
,
obj_dict
in
enumerate
(
cr_dicts
):
plt
.
scatter
(
obj_dict
[
'angle_id'
],
obj_dict
[
'range_id'
],
s
=
10
,
c
=
'white'
)
try
:
obj_dict
[
'object_id'
]
except
:
obj_dict
[
'object_id'
]
=
'%.2f'
%
obj_dict
[
'confidence'
]
try
:
text
=
str
(
obj_dict
[
'object_id'
])
+
' '
+
obj_dict
[
'class'
]
except
:
text
=
str
(
obj_dict
[
'object_id'
])
plt
.
text
(
obj_dict
[
'angle_id'
]
+
5
,
obj_dict
[
'range_id'
],
text
,
color
=
'white'
,
fontsize
=
10
)
plt
.
xlim
(
0
,
chirp_shape
[
1
])
plt
.
ylim
(
0
,
chirp_shape
[
0
])
if
viz
:
plt
.
show
()
else
:
plt
.
savefig
(
figname
)
plt
.
close
()
def
visualize_anno_ramap
(
chirp
,
obj_info
,
figname
,
viz
=
False
):
chirp_abs
=
chirp_amp
(
chirp
)
plt
.
figure
()
plt
.
imshow
(
chirp_abs
,
vmin
=
0
,
vmax
=
1
,
origin
=
'lower'
)
for
obj
in
obj_info
:
rng_idx
,
agl_idx
,
class_id
=
obj
if
class_id
>=
0
:
try
:
cla_str
=
class_table
[
class_id
]
except
:
continue
else
:
continue
plt
.
scatter
(
agl_idx
,
rng_idx
,
s
=
10
,
c
=
'white'
)
plt
.
text
(
agl_idx
+
5
,
rng_idx
,
cla_str
,
color
=
'white'
,
fontsize
=
10
)
if
viz
:
plt
.
show
()
else
:
plt
.
savefig
(
figname
)
plt
.
close
()
tools/prepare_dataset/prepare_data.py
0 → 100644
View file @
a8863510
import
os
import
sys
import
shutil
import
numpy
as
np
import
json
import
pickle
import
argparse
from
cruw.cruw
import
CRUW
from
rodnet.core.confidence_map
import
generate_confmap
,
normalize_confmap
,
add_noise_channel
from
rodnet.utils.load_configs
import
load_configs_from_file
from
rodnet.utils.visualization
import
visualize_confmap
SPLITS_LIST
=
[
'train'
,
'valid'
,
'test'
,
'demo'
]
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
'Prepare RODNet data.'
)
parser
.
add_argument
(
'--config'
,
type
=
str
,
dest
=
'config'
,
help
=
'configuration file path'
)
parser
.
add_argument
(
'--data_root'
,
type
=
str
,
help
=
'directory to the prepared data'
)
parser
.
add_argument
(
'--split'
,
type
=
str
,
dest
=
'split'
,
help
=
'choose from train, valid, test, supertest'
)
parser
.
add_argument
(
'--out_data_dir'
,
type
=
str
,
default
=
'./data'
,
help
=
'data directory to save the prepared data'
)
parser
.
add_argument
(
'--overwrite'
,
action
=
"store_true"
,
help
=
"overwrite prepared data if exist"
)
args
=
parser
.
parse_args
()
return
args
def
prepare_data
(
dataset
,
config_dict
,
data_dir
,
split
,
save_dir
,
viz
=
False
,
overwrite
=
False
):
"""
Prepare pickle data for RODNet training and testing
:param dataset: dataset object
:param config_dict: rodnet configurations
:param data_dir: output directory of the processed data
:param split: train, valid, test, demo, etc.
:param viz: whether visualize the prepared data
:param overwrite: whether overwrite the existing prepared data
:return:
"""
camera_configs
=
dataset
.
sensor_cfg
.
camera_cfg
radar_configs
=
dataset
.
sensor_cfg
.
radar_cfg
n_chirp
=
radar_configs
[
'n_chirps'
]
n_class
=
dataset
.
object_cfg
.
n_class
data_root
=
config_dict
[
'dataset_cfg'
][
'data_root'
]
anno_root
=
config_dict
[
'dataset_cfg'
][
'anno_root'
]
set_cfg
=
config_dict
[
'dataset_cfg'
][
split
]
sets_seqs
=
set_cfg
[
'seqs'
]
if
overwrite
:
if
os
.
path
.
exists
(
os
.
path
.
join
(
data_dir
,
split
)):
shutil
.
rmtree
(
os
.
path
.
join
(
data_dir
,
split
))
os
.
makedirs
(
os
.
path
.
join
(
data_dir
,
split
))
for
seq
in
sets_seqs
:
seq_path
=
os
.
path
.
join
(
data_root
,
seq
)
seq_anno_path
=
os
.
path
.
join
(
anno_root
,
seq
+
'.json'
)
save_path
=
os
.
path
.
join
(
save_dir
,
seq
+
'.pkl'
)
print
(
"Sequence %s saving to %s"
%
(
seq_path
,
save_path
))
try
:
if
not
overwrite
and
os
.
path
.
exists
(
save_path
):
print
(
"%s already exists, skip"
%
save_path
)
continue
image_dir
=
os
.
path
.
join
(
seq_path
,
camera_configs
[
'image_folder'
])
image_paths
=
sorted
([
os
.
path
.
join
(
image_dir
,
name
)
for
name
in
os
.
listdir
(
image_dir
)
if
name
.
endswith
(
camera_configs
[
'ext'
])])
n_frame
=
len
(
image_paths
)
radar_dir
=
os
.
path
.
join
(
seq_path
,
dataset
.
sensor_cfg
.
radar_cfg
[
'chirp_folder'
])
if
radar_configs
[
'data_type'
]
==
'RI'
or
radar_configs
[
'data_type'
]
==
'AP'
:
radar_paths
=
sorted
([
os
.
path
.
join
(
radar_dir
,
name
)
for
name
in
os
.
listdir
(
radar_dir
)
if
name
.
endswith
(
dataset
.
sensor_cfg
.
radar_cfg
[
'ext'
])])
n_radar_frame
=
len
(
radar_paths
)
assert
n_frame
==
n_radar_frame
elif
radar_configs
[
'data_type'
]
==
'RISEP'
or
radar_configs
[
'data_type'
]
==
'APSEP'
:
radar_paths_chirp
=
[]
for
chirp_id
in
range
(
n_chirp
):
chirp_dir
=
os
.
path
.
join
(
radar_dir
,
'%04d'
%
chirp_id
)
paths
=
sorted
([
os
.
path
.
join
(
chirp_dir
,
name
)
for
name
in
os
.
listdir
(
chirp_dir
)
if
name
.
endswith
(
config_dict
[
'dataset_cfg'
][
'radar_cfg'
][
'ext'
])])
n_radar_frame
=
len
(
paths
)
assert
n_frame
==
n_radar_frame
radar_paths_chirp
.
append
(
paths
)
radar_paths
=
[]
for
frame_id
in
range
(
n_frame
):
frame_paths
=
[]
for
chirp_id
in
range
(
n_chirp
):
frame_paths
.
append
(
radar_paths_chirp
[
chirp_id
][
frame_id
])
radar_paths
.
append
(
frame_paths
)
else
:
raise
ValueError
data_dict
=
dict
(
data_root
=
data_root
,
data_path
=
seq_path
,
seq_name
=
seq
,
n_frame
=
n_frame
,
image_paths
=
image_paths
,
radar_paths
=
radar_paths
,
anno
=
None
,
)
if
split
==
'demo'
:
# no labels need to be saved
pickle
.
dump
(
data_dict
,
open
(
save_path
,
'wb'
))
continue
else
:
with
open
(
os
.
path
.
join
(
seq_anno_path
),
'r'
)
as
f
:
anno
=
json
.
load
(
f
)
anno_obj
=
{}
anno_obj
[
'metadata'
]
=
anno
[
'metadata'
]
anno_obj
[
'confmaps'
]
=
[]
for
metadata_frame
in
anno
[
'metadata'
]:
n_obj
=
metadata_frame
[
'rad_h'
][
'n_objects'
]
obj_info
=
metadata_frame
[
'rad_h'
][
'obj_info'
]
if
n_obj
==
0
:
confmap_gt
=
np
.
zeros
(
(
n_class
+
1
,
radar_configs
[
'ramap_rsize'
],
radar_configs
[
'ramap_asize'
]),
dtype
=
float
)
confmap_gt
[
-
1
,
:,
:]
=
1.0
# initialize noise channal
else
:
confmap_gt
=
generate_confmap
(
n_obj
,
obj_info
,
dataset
,
config_dict
)
confmap_gt
=
normalize_confmap
(
confmap_gt
)
confmap_gt
=
add_noise_channel
(
confmap_gt
,
dataset
,
config_dict
)
assert
confmap_gt
.
shape
==
(
n_class
+
1
,
radar_configs
[
'ramap_rsize'
],
radar_configs
[
'ramap_asize'
])
if
viz
:
visualize_confmap
(
confmap_gt
)
anno_obj
[
'confmaps'
].
append
(
confmap_gt
)
# end objects loop
anno_obj
[
'confmaps'
]
=
np
.
array
(
anno_obj
[
'confmaps'
])
data_dict
[
'anno'
]
=
anno_obj
# save pkl files
pickle
.
dump
(
data_dict
,
open
(
save_path
,
'wb'
))
# end frames loop
except
Exception
as
e
:
print
(
"Error while preparing %s: %s"
%
(
seq_path
,
e
))
if
__name__
==
"__main__"
:
args
=
parse_args
()
data_root
=
args
.
data_root
splits
=
args
.
split
.
split
(
','
)
out_data_dir
=
args
.
out_data_dir
overwrite
=
args
.
overwrite
dataset
=
CRUW
(
data_root
=
data_root
)
config_dict
=
load_configs_from_file
(
args
.
config
)
radar_configs
=
dataset
.
sensor_cfg
.
radar_cfg
for
split
in
splits
:
if
split
not
in
SPLITS_LIST
:
raise
TypeError
(
"split %s cannot be recognized"
%
split
)
for
split
in
splits
:
save_dir
=
os
.
path
.
join
(
out_data_dir
,
split
)
if
not
os
.
path
.
exists
(
save_dir
):
os
.
makedirs
(
save_dir
)
print
(
'Preparing %s sets ...'
%
split
)
prepare_data
(
dataset
,
config_dict
,
out_data_dir
,
split
,
save_dir
,
viz
=
False
,
overwrite
=
overwrite
)
tools/test.py
0 → 100644
View file @
a8863510
import
os
import
time
import
argparse
import
numpy
as
np
import
torch
from
torch.utils.data
import
DataLoader
from
cruw.cruw
import
CRUW
from
rodnet.datasets.CRDataset
import
CRDataset
from
rodnet.datasets.collate_functions
import
cr_collate
from
rodnet.core.post_processing
import
post_process
,
post_process_single_frame
from
rodnet.core.post_processing
import
write_dets_results
,
write_dets_results_single_frame
from
rodnet.core.post_processing
import
ConfmapStack
from
rodnet.core.radar_processing
import
chirp_amp
from
rodnet.utils.visualization
import
visualize_test_img
,
visualize_test_img_wo_gt
from
rodnet.utils.load_configs
import
load_configs_from_file
from
rodnet.utils.solve_dir
import
create_random_model_name
"""
Example:
python test.py -m HG -dd /mnt/ssd2/rodnet/data/ -ld /mnt/ssd2/rodnet/checkpoints/
\
-md HG-20200122-104604 -rd /mnt/ssd2/rodnet/results/
"""
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
'Test RODNet.'
)
parser
.
add_argument
(
'--config'
,
type
=
str
,
help
=
'choose rodnet model configurations'
)
parser
.
add_argument
(
'--data_dir'
,
type
=
str
,
default
=
'./data/'
,
help
=
'directory to the prepared data'
)
parser
.
add_argument
(
'--checkpoint'
,
type
=
str
,
help
=
'path to the saved trained model'
)
parser
.
add_argument
(
'--res_dir'
,
type
=
str
,
default
=
'./results/'
,
help
=
'directory to save testing results'
)
parser
.
add_argument
(
'--use_noise_channel'
,
action
=
"store_true"
,
help
=
"use noise channel or not"
)
parser
.
add_argument
(
'--demo'
,
action
=
"store_true"
,
help
=
'False: test with GT, True: demo without GT'
)
parser
.
add_argument
(
'--symbol'
,
action
=
"store_true"
,
help
=
'use symbol or text+score'
)
args
=
parser
.
parse_args
()
return
args
if
__name__
==
"__main__"
:
args
=
parse_args
()
sybl
=
args
.
symbol
config_dict
=
load_configs_from_file
(
args
.
config
)
dataset
=
CRUW
(
data_root
=
config_dict
[
'dataset_cfg'
][
'base_root'
])
radar_configs
=
dataset
.
sensor_cfg
.
radar_cfg
range_grid
=
dataset
.
range_grid
angle_grid
=
dataset
.
angle_grid
# config_dict['mappings'] = {}
# config_dict['mappings']['range_grid'] = range_grid.tolist()
# config_dict['mappings']['angle_grid'] = angle_grid.tolist()
model_configs
=
config_dict
[
'model_cfg'
]
if
model_configs
[
'type'
]
==
'CDC'
:
from
rodnet.models
import
RODNetCDC
as
RODNet
elif
model_configs
[
'type'
]
==
'HG'
:
from
rodnet.models
import
RODNetHG
as
RODNet
elif
model_configs
[
'type'
]
==
'HGwI'
:
from
rodnet.models
import
RODNetHGwI
as
RODNet
else
:
raise
NotImplementedError
# parameter settings
dataset_configs
=
config_dict
[
'dataset_cfg'
]
train_configs
=
config_dict
[
'train_cfg'
]
test_configs
=
config_dict
[
'test_cfg'
]
win_size
=
train_configs
[
'win_size'
]
n_class
=
dataset
.
object_cfg
.
n_class
confmap_shape
=
(
n_class
,
radar_configs
[
'ramap_rsize'
],
radar_configs
[
'ramap_asize'
])
if
'stacked_num'
in
model_configs
:
stacked_num
=
model_configs
[
'stacked_num'
]
else
:
stacked_num
=
None
if
args
.
checkpoint
is
not
None
and
os
.
path
.
exists
(
args
.
checkpoint
):
checkpoint_path
=
args
.
checkpoint
else
:
raise
ValueError
(
"No trained model found."
)
if
args
.
use_noise_channel
:
n_class_test
=
n_class
+
1
else
:
n_class_test
=
n_class
print
(
"Building model ... (%s)"
%
model_configs
)
if
model_configs
[
'type'
]
==
'CDC'
:
rodnet
=
RODNet
(
n_class_test
).
cuda
()
elif
model_configs
[
'type'
]
==
'HG'
:
rodnet
=
RODNet
(
n_class_test
,
stacked_num
=
stacked_num
).
cuda
()
elif
model_configs
[
'type'
]
==
'HGwI'
:
rodnet
=
RODNet
(
n_class_test
,
stacked_num
=
stacked_num
).
cuda
()
else
:
raise
TypeError
checkpoint
=
torch
.
load
(
checkpoint_path
)
if
'optimizer_state_dict'
in
checkpoint
:
rodnet
.
load_state_dict
(
checkpoint
[
'model_state_dict'
])
else
:
rodnet
.
load_state_dict
(
checkpoint
)
if
'model_name'
in
checkpoint
:
model_name
=
checkpoint
[
'model_name'
]
else
:
model_name
=
create_random_model_name
(
model_configs
[
'name'
],
checkpoint_path
)
rodnet
.
eval
()
test_res_dir
=
os
.
path
.
join
(
os
.
path
.
join
(
args
.
res_dir
,
model_name
))
if
not
os
.
path
.
exists
(
test_res_dir
):
os
.
makedirs
(
test_res_dir
)
# save current checkpoint path
weight_log_path
=
os
.
path
.
join
(
test_res_dir
,
'weight_name.txt'
)
if
os
.
path
.
exists
(
weight_log_path
):
with
open
(
weight_log_path
,
'a+'
)
as
f
:
f
.
write
(
checkpoint_path
+
'
\n
'
)
else
:
with
open
(
weight_log_path
,
'w'
)
as
f
:
f
.
write
(
checkpoint_path
+
'
\n
'
)
total_time
=
0
total_count
=
0
data_root
=
dataset_configs
[
'data_root'
]
if
not
args
.
demo
:
seq_names
=
dataset_configs
[
'test'
][
'seqs'
]
else
:
seq_names
=
dataset_configs
[
'demo'
][
'seqs'
]
print
(
seq_names
)
for
seq_name
in
seq_names
:
seq_res_dir
=
os
.
path
.
join
(
test_res_dir
,
seq_name
)
if
not
os
.
path
.
exists
(
seq_res_dir
):
os
.
makedirs
(
seq_res_dir
)
seq_res_viz_dir
=
os
.
path
.
join
(
seq_res_dir
,
'rod_viz'
)
if
not
os
.
path
.
exists
(
seq_res_viz_dir
):
os
.
makedirs
(
seq_res_viz_dir
)
f
=
open
(
os
.
path
.
join
(
seq_res_dir
,
'rod_res.txt'
),
'w'
)
f
.
close
()
for
subset
in
seq_names
:
print
(
subset
)
if
not
args
.
demo
:
crdata_test
=
CRDataset
(
data_dir
=
args
.
data_dir
,
dataset
=
dataset
,
config_dict
=
config_dict
,
split
=
'test'
,
noise_channel
=
args
.
use_noise_channel
,
subset
=
subset
,
is_random_chirp
=
False
)
else
:
crdata_test
=
CRDataset
(
data_dir
=
args
.
data_dir
,
dataset
=
dataset
,
config_dict
=
config_dict
,
split
=
'demo'
,
noise_channel
=
args
.
use_noise_channel
,
subset
=
subset
,
is_random_chirp
=
False
)
print
(
"Length of testing data: %d"
%
len
(
crdata_test
))
dataloader
=
DataLoader
(
crdata_test
,
batch_size
=
1
,
shuffle
=
False
,
num_workers
=
0
,
collate_fn
=
cr_collate
)
seq_names
=
crdata_test
.
seq_names
index_mapping
=
crdata_test
.
index_mapping
init_genConfmap
=
ConfmapStack
(
confmap_shape
)
iter_
=
init_genConfmap
for
i
in
range
(
train_configs
[
'win_size'
]
-
1
):
while
iter_
.
next
is
not
None
:
iter_
=
iter_
.
next
iter_
.
next
=
ConfmapStack
(
confmap_shape
)
load_tic
=
time
.
time
()
for
iter
,
data_dict
in
enumerate
(
dataloader
):
load_time
=
time
.
time
()
-
load_tic
data
=
data_dict
[
'radar_data'
]
image_paths
=
data_dict
[
'image_paths'
][
0
]
seq_name
=
data_dict
[
'seq_names'
][
0
]
if
not
args
.
demo
:
confmap_gt
=
data_dict
[
'anno'
][
'confmaps'
]
obj_info
=
data_dict
[
'anno'
][
'obj_infos'
]
else
:
confmap_gt
=
None
obj_info
=
None
save_path
=
os
.
path
.
join
(
test_res_dir
,
seq_name
,
'rod_res.txt'
)
start_frame_name
=
image_paths
[
0
].
split
(
'/'
)[
-
1
].
split
(
'.'
)[
0
]
end_frame_name
=
image_paths
[
-
1
].
split
(
'/'
)[
-
1
].
split
(
'.'
)[
0
]
start_frame_id
=
int
(
start_frame_name
)
end_frame_id
=
int
(
end_frame_name
)
print
(
"Testing %s: %s-%s"
%
(
seq_name
,
start_frame_name
,
end_frame_name
))
tic
=
time
.
time
()
confmap_pred
=
rodnet
(
data
.
float
().
cuda
())
if
stacked_num
is
not
None
:
confmap_pred
=
confmap_pred
[
-
1
].
cpu
().
detach
().
numpy
()
# (1, 4, 32, 128, 128)
else
:
confmap_pred
=
confmap_pred
.
cpu
().
detach
().
numpy
()
if
args
.
use_noise_channel
:
confmap_pred
=
confmap_pred
[:,
:
n_class
,
:,
:,
:]
infer_time
=
time
.
time
()
-
tic
total_time
+=
infer_time
iter_
=
init_genConfmap
for
i
in
range
(
confmap_pred
.
shape
[
2
]):
if
iter_
.
next
is
None
and
i
!=
confmap_pred
.
shape
[
2
]
-
1
:
iter_
.
next
=
ConfmapStack
(
confmap_shape
)
iter_
.
append
(
confmap_pred
[
0
,
:,
i
,
:,
:])
iter_
=
iter_
.
next
process_tic
=
time
.
time
()
for
i
in
range
(
test_configs
[
'test_stride'
]):
total_count
+=
1
res_final
=
post_process_single_frame
(
init_genConfmap
.
confmap
,
dataset
,
config_dict
)
cur_frame_id
=
start_frame_id
+
i
write_dets_results_single_frame
(
res_final
,
cur_frame_id
,
save_path
,
dataset
)
confmap_pred_0
=
init_genConfmap
.
confmap
res_final_0
=
res_final
img_path
=
image_paths
[
i
]
radar_input
=
chirp_amp
(
data
.
numpy
()[
0
,
:,
i
,
:,
:],
radar_configs
[
'data_type'
])
fig_name
=
os
.
path
.
join
(
test_res_dir
,
seq_name
,
'rod_viz'
,
'%010d.jpg'
%
(
cur_frame_id
))
if
confmap_gt
is
not
None
:
confmap_gt_0
=
confmap_gt
[
0
,
:,
i
,
:,
:]
visualize_test_img
(
fig_name
,
img_path
,
radar_input
,
confmap_pred_0
,
confmap_gt_0
,
res_final_0
,
dataset
,
sybl
=
sybl
)
else
:
visualize_test_img_wo_gt
(
fig_name
,
img_path
,
radar_input
,
confmap_pred_0
,
res_final_0
,
dataset
,
sybl
=
sybl
)
init_genConfmap
=
init_genConfmap
.
next
if
iter
==
len
(
dataloader
)
-
1
:
offset
=
test_configs
[
'test_stride'
]
cur_frame_id
=
start_frame_id
+
offset
while
init_genConfmap
is
not
None
:
total_count
+=
1
res_final
=
post_process_single_frame
(
init_genConfmap
.
confmap
,
dataset
,
config_dict
)
write_dets_results_single_frame
(
res_final
,
cur_frame_id
,
save_path
,
dataset
)
confmap_pred_0
=
init_genConfmap
.
confmap
res_final_0
=
res_final
img_path
=
image_paths
[
offset
]
radar_input
=
chirp_amp
(
data
.
numpy
()[
0
,
:,
offset
,
:,
:],
radar_configs
[
'data_type'
])
fig_name
=
os
.
path
.
join
(
test_res_dir
,
seq_name
,
'rod_viz'
,
'%010d.jpg'
%
(
cur_frame_id
))
if
confmap_gt
is
not
None
:
confmap_gt_0
=
confmap_gt
[
0
,
:,
offset
,
:,
:]
visualize_test_img
(
fig_name
,
img_path
,
radar_input
,
confmap_pred_0
,
confmap_gt_0
,
res_final_0
,
dataset
,
sybl
=
sybl
)
else
:
visualize_test_img_wo_gt
(
fig_name
,
img_path
,
radar_input
,
confmap_pred_0
,
res_final_0
,
dataset
,
sybl
=
sybl
)
init_genConfmap
=
init_genConfmap
.
next
offset
+=
1
cur_frame_id
+=
1
if
init_genConfmap
is
None
:
init_genConfmap
=
ConfmapStack
(
confmap_shape
)
proc_time
=
time
.
time
()
-
process_tic
print
(
"Load time: %.4f | Inference time: %.4f | Process time: %.4f"
%
(
load_time
,
infer_time
,
proc_time
))
load_tic
=
time
.
time
()
print
(
"ave time: %f"
%
(
total_time
/
total_count
))
tools/train.py
0 → 100644
View file @
a8863510
import
os
import
time
import
json
import
argparse
import
torch
import
torch.nn
as
nn
import
torch.optim
as
optim
from
torch.optim.lr_scheduler
import
StepLR
from
torch.utils.data
import
DataLoader
from
torch.utils.tensorboard
import
SummaryWriter
from
cruw.cruw
import
CRUW
from
rodnet.datasets.CRDataset
import
CRDataset
from
rodnet.datasets.CRDatasetSM
import
CRDatasetSM
from
rodnet.datasets.CRDataLoader
import
CRDataLoader
from
rodnet.datasets.collate_functions
import
cr_collate
from
rodnet.core.radar_processing
import
chirp_amp
from
rodnet.utils.solve_dir
import
create_dir_for_new_model
from
rodnet.utils.load_configs
import
load_configs_from_file
from
rodnet.utils.visualization
import
visualize_train_img
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
'Train RODNet.'
)
parser
.
add_argument
(
'--config'
,
type
=
str
,
help
=
'configuration file path'
)
parser
.
add_argument
(
'--data_dir'
,
type
=
str
,
default
=
'./data/'
,
help
=
'directory to the prepared data'
)
parser
.
add_argument
(
'--log_dir'
,
type
=
str
,
default
=
'./checkpoints/'
,
help
=
'directory to save trained model'
)
parser
.
add_argument
(
'--resume_from'
,
type
=
str
,
default
=
None
,
help
=
'path to the trained model'
)
parser
.
add_argument
(
'--save_memory'
,
action
=
"store_true"
,
help
=
"use customized dataloader to save memory"
)
parser
.
add_argument
(
'--use_noise_channel'
,
action
=
"store_true"
,
help
=
"use noise channel or not"
)
args
=
parser
.
parse_args
()
return
args
if
__name__
==
"__main__"
:
args
=
parse_args
()
config_dict
=
load_configs_from_file
(
args
.
config
)
dataset
=
CRUW
(
data_root
=
config_dict
[
'dataset_cfg'
][
'base_root'
])
radar_configs
=
dataset
.
sensor_cfg
.
radar_cfg
range_grid
=
dataset
.
range_grid
angle_grid
=
dataset
.
angle_grid
# config_dict['mappings'] = {}
# config_dict['mappings']['range_grid'] = range_grid.tolist()
# config_dict['mappings']['angle_grid'] = angle_grid.tolist()
model_cfg
=
config_dict
[
'model_cfg'
]
if
model_cfg
[
'type'
]
==
'CDC'
:
from
rodnet.models
import
RODNetCDC
as
RODNet
elif
model_cfg
[
'type'
]
==
'HG'
:
from
rodnet.models
import
RODNetHG
as
RODNet
elif
model_cfg
[
'type'
]
==
'HGwI'
:
from
rodnet.models
import
RODNetHGwI
as
RODNet
else
:
raise
NotImplementedError
if
not
os
.
path
.
exists
(
args
.
log_dir
):
os
.
makedirs
(
args
.
log_dir
)
train_model_path
=
args
.
log_dir
# create / load models
cp_path
=
None
epoch_start
=
0
iter_start
=
0
if
args
.
resume_from
is
not
None
and
os
.
path
.
exists
(
args
.
resume_from
):
cp_path
=
args
.
resume_from
model_dir
,
model_name
=
create_dir_for_new_model
(
model_cfg
[
'name'
],
train_model_path
)
else
:
model_dir
,
model_name
=
create_dir_for_new_model
(
model_cfg
[
'name'
],
train_model_path
)
train_viz_path
=
os
.
path
.
join
(
model_dir
,
'train_viz'
)
if
not
os
.
path
.
exists
(
train_viz_path
):
os
.
makedirs
(
train_viz_path
)
writer
=
SummaryWriter
(
model_dir
)
save_config_dict
=
{
'args'
:
vars
(
args
),
'config_dict'
:
config_dict
,
}
config_json_name
=
os
.
path
.
join
(
model_dir
,
'config-'
+
time
.
strftime
(
"%Y%m%d-%H%M%S"
)
+
'.json'
)
with
open
(
config_json_name
,
'w'
)
as
fp
:
json
.
dump
(
save_config_dict
,
fp
)
train_log_name
=
os
.
path
.
join
(
model_dir
,
"train.log"
)
with
open
(
train_log_name
,
'w'
):
pass
n_class
=
dataset
.
object_cfg
.
n_class
n_epoch
=
config_dict
[
'train_cfg'
][
'n_epoch'
]
batch_size
=
config_dict
[
'train_cfg'
][
'batch_size'
]
lr
=
config_dict
[
'train_cfg'
][
'lr'
]
if
'stacked_num'
in
model_cfg
:
stacked_num
=
model_cfg
[
'stacked_num'
]
else
:
stacked_num
=
None
print
(
"Building dataloader ... (Mode: %s)"
%
(
"save_memory"
if
args
.
save_memory
else
"normal"
))
if
not
args
.
save_memory
:
crdata_train
=
CRDataset
(
data_dir
=
args
.
data_dir
,
dataset
=
dataset
,
config_dict
=
config_dict
,
split
=
'train'
,
noise_channel
=
args
.
use_noise_channel
)
seq_names
=
crdata_train
.
seq_names
index_mapping
=
crdata_train
.
index_mapping
dataloader
=
DataLoader
(
crdata_train
,
batch_size
,
shuffle
=
True
,
num_workers
=
0
,
collate_fn
=
cr_collate
)
# crdata_valid = CRDataset(os.path.join(args.data_dir, 'data_details'),
# os.path.join(args.data_dir, 'confmaps_gt'),
# win_size=win_size, set_type='valid', stride=8)
# seq_names_valid = crdata_valid.seq_names
# index_mapping_valid = crdata_valid.index_mapping
# dataloader_valid = DataLoader(crdata_valid, batch_size=batch_size, shuffle=True, num_workers=0)
else
:
crdata_train
=
CRDatasetSM
(
data_root
=
args
.
data_dir
,
config_dict
=
config_dict
,
split
=
'train'
,
noise_channel
=
args
.
use_noise_channel
)
seq_names
=
crdata_train
.
seq_names
index_mapping
=
crdata_train
.
index_mapping
dataloader
=
CRDataLoader
(
crdata_train
,
shuffle
=
True
,
noise_channel
=
args
.
use_noise_channel
)
# crdata_valid = CRDatasetSM(os.path.join(args.data_dir, 'data_details'),
# os.path.join(args.data_dir, 'confmaps_gt'),
# win_size=win_size, set_type='train', stride=8, is_Memory_Limit=True)
# seq_names_valid = crdata_valid.seq_names
# index_mapping_valid = crdata_valid.index_mapping
# dataloader_valid = CRDataLoader(crdata_valid, batch_size=batch_size, shuffle=True)
if
args
.
use_noise_channel
:
n_class_train
=
n_class
+
1
else
:
n_class_train
=
n_class
print
(
"Building model ... (%s)"
%
model_cfg
)
if
model_cfg
[
'type'
]
==
'CDC'
:
if
'mnet_cfg'
in
model_cfg
:
rodnet
=
RODNet
(
n_class_train
,
mnet_cfg
=
model_cfg
[
'mnet_cfg'
]).
cuda
()
else
:
rodnet
=
RODNet
(
n_class_train
).
cuda
()
criterion
=
nn
.
MSELoss
()
elif
model_cfg
[
'type'
]
==
'HG'
:
if
'mnet_cfg'
in
model_cfg
:
rodnet
=
RODNet
(
n_class_train
,
stacked_num
=
stacked_num
,
mnet_cfg
=
model_cfg
[
'mnet_cfg'
]).
cuda
()
else
:
rodnet
=
RODNet
(
n_class_train
,
stacked_num
=
stacked_num
).
cuda
()
criterion
=
nn
.
BCELoss
()
elif
model_cfg
[
'type'
]
==
'HGwI'
:
if
'mnet_cfg'
in
model_cfg
:
rodnet
=
RODNet
(
n_class_train
,
stacked_num
=
stacked_num
,
mnet_cfg
=
model_cfg
[
'mnet_cfg'
]).
cuda
()
else
:
rodnet
=
RODNet
(
n_class_train
,
stacked_num
=
stacked_num
).
cuda
()
criterion
=
nn
.
BCELoss
()
else
:
raise
TypeError
# criterion = FocalLoss(focusing_param=8, balance_param=0.25)
optimizer
=
optim
.
Adam
(
rodnet
.
parameters
(),
lr
=
lr
)
scheduler
=
StepLR
(
optimizer
,
step_size
=
config_dict
[
'train_cfg'
][
'lr_step'
],
gamma
=
0.1
)
iter_count
=
0
if
cp_path
is
not
None
:
checkpoint
=
torch
.
load
(
cp_path
)
if
'optimizer_state_dict'
in
checkpoint
:
rodnet
.
load_state_dict
(
checkpoint
[
'model_state_dict'
])
optimizer
.
load_state_dict
(
checkpoint
[
'optimizer_state_dict'
])
epoch_start
=
checkpoint
[
'epoch'
]
+
1
iter_start
=
checkpoint
[
'iter'
]
+
1
loss_cp
=
checkpoint
[
'loss'
]
if
'iter_count'
in
checkpoint
:
iter_count
=
checkpoint
[
'iter_count'
]
else
:
rodnet
.
load_state_dict
(
checkpoint
)
# print training configurations
print
(
"Model name: %s"
%
model_name
)
print
(
"Number of sequences to train: %d"
%
crdata_train
.
n_seq
)
print
(
"Training dataset length: %d"
%
len
(
crdata_train
))
print
(
"Batch size: %d"
%
batch_size
)
print
(
"Number of iterations in each epoch: %d"
%
int
(
len
(
crdata_train
)
/
batch_size
))
for
epoch
in
range
(
epoch_start
,
n_epoch
):
tic_load
=
time
.
time
()
# if epoch == epoch_start:
# dataloader_start = iter_start
# else:
# dataloader_start = 0
for
iter
,
data_dict
in
enumerate
(
dataloader
):
data
=
data_dict
[
'radar_data'
]
image_paths
=
data_dict
[
'image_paths'
]
confmap_gt
=
data_dict
[
'anno'
][
'confmaps'
]
if
not
data_dict
[
'status'
]:
# in case load npy fail
print
(
"Warning: Loading NPY data failed! Skip this iteration"
)
tic_load
=
time
.
time
()
continue
tic
=
time
.
time
()
optimizer
.
zero_grad
()
# zero the parameter gradients
confmap_preds
=
rodnet
(
data
.
float
().
cuda
())
loss_confmap
=
0
if
stacked_num
is
not
None
:
for
i
in
range
(
stacked_num
):
loss_cur
=
criterion
(
confmap_preds
[
i
],
confmap_gt
.
float
().
cuda
())
loss_confmap
+=
loss_cur
loss_confmap
.
backward
()
optimizer
.
step
()
else
:
loss_confmap
=
criterion
(
confmap_preds
,
confmap_gt
.
float
().
cuda
())
loss_confmap
.
backward
()
optimizer
.
step
()
if
iter
%
config_dict
[
'train_cfg'
][
'log_step'
]
==
0
:
# print statistics
print
(
'epoch %2d, iter %4d: loss: %.8f | load time: %.4f | backward time: %.4f'
%
(
epoch
+
1
,
iter
+
1
,
loss_confmap
.
item
(),
tic
-
tic_load
,
time
.
time
()
-
tic
))
with
open
(
train_log_name
,
'a+'
)
as
f_log
:
f_log
.
write
(
'epoch %2d, iter %4d: loss: %.8f | load time: %.4f | backward time: %.4f
\n
'
%
(
epoch
+
1
,
iter
+
1
,
loss_confmap
.
item
(),
tic
-
tic_load
,
time
.
time
()
-
tic
))
if
stacked_num
is
not
None
:
writer
.
add_scalar
(
'loss/loss_all'
,
loss_confmap
.
item
(),
iter_count
)
confmap_pred
=
confmap_preds
[
stacked_num
-
1
].
cpu
().
detach
().
numpy
()
else
:
writer
.
add_scalar
(
'loss/loss_all'
,
loss_confmap
.
item
(),
iter_count
)
confmap_pred
=
confmap_preds
.
cpu
().
detach
().
numpy
()
if
'mnet_cfg'
in
model_cfg
:
chirp_amp_curr
=
chirp_amp
(
data
.
numpy
()[
0
,
:,
0
,
0
,
:,
:],
radar_configs
[
'data_type'
])
else
:
chirp_amp_curr
=
chirp_amp
(
data
.
numpy
()[
0
,
:,
0
,
:,
:],
radar_configs
[
'data_type'
])
if
True
:
# draw train images
fig_name
=
os
.
path
.
join
(
train_viz_path
,
'%03d_%010d_%06d.png'
%
(
epoch
+
1
,
iter_count
,
iter
+
1
))
img_path
=
image_paths
[
0
][
0
]
visualize_train_img
(
fig_name
,
img_path
,
chirp_amp_curr
,
confmap_pred
[
0
,
:
n_class
,
0
,
:,
:],
confmap_gt
[
0
,
:
n_class
,
0
,
:,
:])
else
:
writer
.
add_image
(
'images/ramap'
,
heatmap2rgb
(
chirp_amp_curr
),
iter_count
)
writer
.
add_image
(
'images/confmap_pred'
,
prob2image
(
confmap_pred
[
0
,
:,
0
,
:,
:]),
iter_count
)
writer
.
add_image
(
'images/confmap_gt'
,
prob2image
(
confmap_gt
[
0
,
:,
0
,
:,
:]),
iter_count
)
# TODO: combine three images together
# writer.add_images('')
if
(
iter
+
1
)
%
config_dict
[
'train_cfg'
][
'save_step'
]
==
0
:
# validate current model
# print("validing current model ...")
# validate()
# save current model
print
(
"saving current model ..."
)
status_dict
=
{
'model_name'
:
model_name
,
'epoch'
:
epoch
,
'iter'
:
iter
,
'model_state_dict'
:
rodnet
.
state_dict
(),
'optimizer_state_dict'
:
optimizer
.
state_dict
(),
'loss'
:
loss_confmap
,
'iter_count'
:
iter_count
,
}
save_model_path
=
'%s/epoch_%02d_iter_%010d.pkl'
%
(
model_dir
,
epoch
+
1
,
iter_count
+
1
)
torch
.
save
(
status_dict
,
save_model_path
)
iter_count
+=
1
tic_load
=
time
.
time
()
# save current model
print
(
"saving current epoch model ..."
)
status_dict
=
{
'model_name'
:
model_name
,
'epoch'
:
epoch
,
'iter'
:
iter
,
'model_state_dict'
:
rodnet
.
state_dict
(),
'optimizer_state_dict'
:
optimizer
.
state_dict
(),
'loss'
:
loss_confmap
,
'iter_count'
:
iter_count
,
}
save_model_path
=
'%s/epoch_%02d_final.pkl'
%
(
model_dir
,
epoch
+
1
)
torch
.
save
(
status_dict
,
save_model_path
)
scheduler
.
step
()
print
(
'Training Finished.'
)
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment