Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
RODNet
Commits
1d3dead7
Commit
1d3dead7
authored
Dec 30, 2020
by
Yizhou Wang
Browse files
update base code for ROD2021
parent
81f1e0ac
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
109 additions
and
79 deletions
+109
-79
rodnet/core/radar_processing/chirp_ops.py
rodnet/core/radar_processing/chirp_ops.py
+1
-1
rodnet/datasets/CRDataset.py
rodnet/datasets/CRDataset.py
+11
-10
rodnet/datasets/loaders/__init__.py
rodnet/datasets/loaders/__init__.py
+1
-1
rodnet/datasets/loaders/parse_pkl.py
rodnet/datasets/loaders/parse_pkl.py
+5
-0
tools/prepare_dataset/prepare_data.py
tools/prepare_dataset/prepare_data.py
+78
-32
tools/train.py
tools/train.py
+13
-35
No files found.
rodnet/core/radar_processing/chirp_ops.py
View file @
1d3dead7
...
@@ -9,7 +9,7 @@ def chirp_amp(chirp, radar_data_type):
...
@@ -9,7 +9,7 @@ def chirp_amp(chirp, radar_data_type):
:return: amplitude map for the input chirp (w x h)
:return: amplitude map for the input chirp (w x h)
"""
"""
c0
,
c1
,
c2
=
chirp
.
shape
c0
,
c1
,
c2
=
chirp
.
shape
if
radar_data_type
==
'RI'
or
radar_data_type
==
'RISEP'
:
if
radar_data_type
==
'RI'
or
radar_data_type
==
'RISEP'
or
radar_data_type
==
'ROD2021'
:
if
c0
==
2
:
if
c0
==
2
:
chirp_abs
=
np
.
sqrt
(
chirp
[
0
,
:,
:]
**
2
+
chirp
[
1
,
:,
:]
**
2
)
chirp_abs
=
np
.
sqrt
(
chirp
[
0
,
:,
:]
**
2
+
chirp
[
1
,
:,
:]
**
2
)
elif
c2
==
2
:
elif
c2
==
2
:
...
...
rodnet/datasets/CRDataset.py
View file @
1d3dead7
...
@@ -7,7 +7,7 @@ from tqdm import tqdm
...
@@ -7,7 +7,7 @@ from tqdm import tqdm
from
torch.utils
import
data
from
torch.utils
import
data
from
.loaders
import
list_pkl_filenames
from
.loaders
import
list_pkl_filenames
,
list_pkl_filenames_from_prepared
class
CRDataset
(
data
.
Dataset
):
class
CRDataset
(
data
.
Dataset
):
...
@@ -61,7 +61,8 @@ class CRDataset(data.Dataset):
...
@@ -61,7 +61,8 @@ class CRDataset(data.Dataset):
if
subset
is
not
None
:
if
subset
is
not
None
:
self
.
data_files
=
[
subset
+
'.pkl'
]
self
.
data_files
=
[
subset
+
'.pkl'
]
else
:
else
:
self
.
data_files
=
list_pkl_filenames
(
config_dict
[
'dataset_cfg'
],
split
)
# self.data_files = list_pkl_filenames(config_dict['dataset_cfg'], split)
self
.
data_files
=
list_pkl_filenames_from_prepared
(
data_dir
,
split
)
self
.
seq_names
=
[
name
.
split
(
'.'
)[
0
]
for
name
in
self
.
data_files
]
self
.
seq_names
=
[
name
.
split
(
'.'
)[
0
]
for
name
in
self
.
data_files
]
self
.
n_seq
=
len
(
self
.
seq_names
)
self
.
n_seq
=
len
(
self
.
seq_names
)
...
@@ -142,8 +143,15 @@ class CRDataset(data.Dataset):
...
@@ -142,8 +143,15 @@ class CRDataset(data.Dataset):
data_dict
[
'image_paths'
].
append
(
image_paths
[
frameid
])
data_dict
[
'image_paths'
].
append
(
image_paths
[
frameid
])
else
:
else
:
raise
TypeError
raise
TypeError
elif
radar_configs
[
'data_type'
]
==
'ROD2021'
:
radar_npy_win
=
np
.
zeros
((
self
.
win_size
,
ramap_rsize
,
ramap_asize
,
2
),
dtype
=
np
.
float32
)
chirp_id
=
0
# only use chirp 0 for training
for
idx
,
frameid
in
enumerate
(
range
(
data_id
,
data_id
+
self
.
win_size
*
self
.
step
,
self
.
step
)):
radar_npy_win
[
idx
,
:,
:,
:]
=
np
.
load
(
radar_paths
[
frameid
][
chirp_id
])
data_dict
[
'image_paths'
].
append
(
image_paths
[
frameid
])
else
:
else
:
raise
Value
Error
raise
NotImplemented
Error
except
:
except
:
# in case load npy fail
# in case load npy fail
data_dict
[
'status'
]
=
False
data_dict
[
'status'
]
=
False
...
@@ -202,10 +210,3 @@ class CRDataset(data.Dataset):
...
@@ -202,10 +210,3 @@ class CRDataset(data.Dataset):
data_dict
[
'anno'
]
=
None
data_dict
[
'anno'
]
=
None
return
data_dict
return
data_dict
if
__name__
==
"__main__"
:
dataset
=
CRDataset
(
'./data/data_details'
,
stride
=
16
)
print
(
len
(
dataset
))
for
i
in
range
(
len
(
dataset
)):
continue
rodnet/datasets/loaders/__init__.py
View file @
1d3dead7
from
.parse_pkl
import
list_pkl_filenames
from
.parse_pkl
import
list_pkl_filenames
,
list_pkl_filenames_from_prepared
from
.read_rod_results
import
load_rodnet_res
,
load_vgg_res
from
.read_rod_results
import
load_rodnet_res
,
load_vgg_res
rodnet/datasets/loaders/parse_pkl.py
View file @
1d3dead7
...
@@ -6,3 +6,8 @@ def list_pkl_filenames(dataset_configs, split):
...
@@ -6,3 +6,8 @@ def list_pkl_filenames(dataset_configs, split):
seqs
=
dataset_configs
[
split
][
'seqs'
]
seqs
=
dataset_configs
[
split
][
'seqs'
]
seqs_pkl_names
=
[
name
+
'.pkl'
for
name
in
seqs
]
seqs_pkl_names
=
[
name
+
'.pkl'
for
name
in
seqs
]
return
seqs_pkl_names
return
seqs_pkl_names
def
list_pkl_filenames_from_prepared
(
data_dir
,
split
):
seqs_pkl_names
=
sorted
(
os
.
listdir
(
os
.
path
.
join
(
data_dir
,
split
)))
return
seqs_pkl_names
tools/prepare_dataset/prepare_data.py
View file @
1d3dead7
...
@@ -6,7 +6,9 @@ import json
...
@@ -6,7 +6,9 @@ import json
import
pickle
import
pickle
import
argparse
import
argparse
from
cruw.cruw
import
CRUW
from
cruw
import
CRUW
from
cruw.annotation.init_json
import
init_meta_json
from
cruw.mapping
import
ra2idx
from
rodnet.core.confidence_map
import
generate_confmap
,
normalize_confmap
,
add_noise_channel
from
rodnet.core.confidence_map
import
generate_confmap
,
normalize_confmap
,
add_noise_channel
from
rodnet.utils.load_configs
import
load_configs_from_file
from
rodnet.utils.load_configs
import
load_configs_from_file
...
@@ -27,6 +29,52 @@ def parse_args():
...
@@ -27,6 +29,52 @@ def parse_args():
return
args
return
args
def
load_anno_txt
(
txt_path
,
n_frame
,
dataset
):
folder_name_dict
=
dict
(
cam_0
=
'IMAGES_0'
,
rad_h
=
'RADAR_RA_H'
)
anno_dict
=
init_meta_json
(
n_frame
,
folder_name_dict
)
with
open
(
txt_path
,
'r'
)
as
f
:
data
=
f
.
readlines
()
for
line
in
data
:
frame_id
,
r
,
a
,
class_name
=
line
.
rstrip
().
split
()
frame_id
=
int
(
frame_id
)
r
=
float
(
r
)
a
=
float
(
a
)
rid
,
aid
=
ra2idx
(
r
,
a
,
dataset
.
range_grid
,
dataset
.
angle_grid
)
anno_dict
[
frame_id
][
'rad_h'
][
'n_objects'
]
+=
1
anno_dict
[
frame_id
][
'rad_h'
][
'obj_info'
][
'categories'
].
append
(
class_name
)
anno_dict
[
frame_id
][
'rad_h'
][
'obj_info'
][
'centers'
].
append
([
r
,
a
])
anno_dict
[
frame_id
][
'rad_h'
][
'obj_info'
][
'center_ids'
].
append
([
rid
,
aid
])
anno_dict
[
frame_id
][
'rad_h'
][
'obj_info'
][
'scores'
].
append
(
1.0
)
return
anno_dict
def
generate_confmaps
(
metadata_dict
,
n_class
,
viz
):
confmaps
=
[]
for
metadata_frame
in
metadata_dict
:
n_obj
=
metadata_frame
[
'rad_h'
][
'n_objects'
]
obj_info
=
metadata_frame
[
'rad_h'
][
'obj_info'
]
if
n_obj
==
0
:
confmap_gt
=
np
.
zeros
(
(
n_class
+
1
,
radar_configs
[
'ramap_rsize'
],
radar_configs
[
'ramap_asize'
]),
dtype
=
float
)
confmap_gt
[
-
1
,
:,
:]
=
1.0
# initialize noise channal
else
:
confmap_gt
=
generate_confmap
(
n_obj
,
obj_info
,
dataset
,
config_dict
)
confmap_gt
=
normalize_confmap
(
confmap_gt
)
confmap_gt
=
add_noise_channel
(
confmap_gt
,
dataset
,
config_dict
)
assert
confmap_gt
.
shape
==
(
n_class
+
1
,
radar_configs
[
'ramap_rsize'
],
radar_configs
[
'ramap_asize'
])
if
viz
:
visualize_confmap
(
confmap_gt
)
confmaps
.
append
(
confmap_gt
)
confmaps
=
np
.
array
(
confmaps
)
return
confmaps
def
prepare_data
(
dataset
,
config_dict
,
data_dir
,
split
,
save_dir
,
viz
=
False
,
overwrite
=
False
):
def
prepare_data
(
dataset
,
config_dict
,
data_dir
,
split
,
save_dir
,
viz
=
False
,
overwrite
=
False
):
"""
"""
Prepare pickle data for RODNet training and testing
Prepare pickle data for RODNet training and testing
...
@@ -34,6 +82,7 @@ def prepare_data(dataset, config_dict, data_dir, split, save_dir, viz=False, ove
...
@@ -34,6 +82,7 @@ def prepare_data(dataset, config_dict, data_dir, split, save_dir, viz=False, ove
:param config_dict: rodnet configurations
:param config_dict: rodnet configurations
:param data_dir: output directory of the processed data
:param data_dir: output directory of the processed data
:param split: train, valid, test, demo, etc.
:param split: train, valid, test, demo, etc.
:param save_dir: output directory of the prepared data
:param viz: whether visualize the prepared data
:param viz: whether visualize the prepared data
:param overwrite: whether overwrite the existing prepared data
:param overwrite: whether overwrite the existing prepared data
:return:
:return:
...
@@ -46,6 +95,9 @@ def prepare_data(dataset, config_dict, data_dir, split, save_dir, viz=False, ove
...
@@ -46,6 +95,9 @@ def prepare_data(dataset, config_dict, data_dir, split, save_dir, viz=False, ove
data_root
=
config_dict
[
'dataset_cfg'
][
'data_root'
]
data_root
=
config_dict
[
'dataset_cfg'
][
'data_root'
]
anno_root
=
config_dict
[
'dataset_cfg'
][
'anno_root'
]
anno_root
=
config_dict
[
'dataset_cfg'
][
'anno_root'
]
set_cfg
=
config_dict
[
'dataset_cfg'
][
split
]
set_cfg
=
config_dict
[
'dataset_cfg'
][
split
]
if
'seqs'
not
in
set_cfg
:
sets_seqs
=
sorted
(
os
.
listdir
(
os
.
path
.
join
(
data_root
,
set_cfg
[
'subdir'
])))
else
:
sets_seqs
=
set_cfg
[
'seqs'
]
sets_seqs
=
set_cfg
[
'seqs'
]
if
overwrite
:
if
overwrite
:
...
@@ -54,8 +106,8 @@ def prepare_data(dataset, config_dict, data_dir, split, save_dir, viz=False, ove
...
@@ -54,8 +106,8 @@ def prepare_data(dataset, config_dict, data_dir, split, save_dir, viz=False, ove
os
.
makedirs
(
os
.
path
.
join
(
data_dir
,
split
))
os
.
makedirs
(
os
.
path
.
join
(
data_dir
,
split
))
for
seq
in
sets_seqs
:
for
seq
in
sets_seqs
:
seq_path
=
os
.
path
.
join
(
data_root
,
seq
)
seq_path
=
os
.
path
.
join
(
data_root
,
set_cfg
[
'subdir'
],
seq
)
seq_anno_path
=
os
.
path
.
join
(
anno_root
,
se
q
+
'.json'
)
seq_anno_path
=
os
.
path
.
join
(
anno_root
,
se
t_cfg
[
'subdir'
],
seq
+
config_dict
[
'dataset_cfg'
][
'anno_ext'
]
)
save_path
=
os
.
path
.
join
(
save_dir
,
seq
+
'.pkl'
)
save_path
=
os
.
path
.
join
(
save_dir
,
seq
+
'.pkl'
)
print
(
"Sequence %s saving to %s"
%
(
seq_path
,
save_path
))
print
(
"Sequence %s saving to %s"
%
(
seq_path
,
save_path
))
...
@@ -89,6 +141,16 @@ def prepare_data(dataset, config_dict, data_dir, split, save_dir, viz=False, ove
...
@@ -89,6 +141,16 @@ def prepare_data(dataset, config_dict, data_dir, split, save_dir, viz=False, ove
for
chirp_id
in
range
(
n_chirp
):
for
chirp_id
in
range
(
n_chirp
):
frame_paths
.
append
(
radar_paths_chirp
[
chirp_id
][
frame_id
])
frame_paths
.
append
(
radar_paths_chirp
[
chirp_id
][
frame_id
])
radar_paths
.
append
(
frame_paths
)
radar_paths
.
append
(
frame_paths
)
elif
radar_configs
[
'data_type'
]
==
'ROD2021'
:
assert
len
(
os
.
listdir
(
radar_dir
))
==
n_frame
*
len
(
radar_configs
[
'chirp_ids'
])
radar_paths
=
[]
for
frame_id
in
range
(
n_frame
):
chirp_paths
=
[]
for
chirp_id
in
radar_configs
[
'chirp_ids'
]:
path
=
os
.
path
.
join
(
radar_dir
,
'%06d_%04d.'
%
(
frame_id
,
chirp_id
)
+
dataset
.
sensor_cfg
.
radar_cfg
[
'ext'
])
chirp_paths
.
append
(
path
)
radar_paths
.
append
(
chirp_paths
)
else
:
else
:
raise
ValueError
raise
ValueError
...
@@ -107,35 +169,19 @@ def prepare_data(dataset, config_dict, data_dir, split, save_dir, viz=False, ove
...
@@ -107,35 +169,19 @@ def prepare_data(dataset, config_dict, data_dir, split, save_dir, viz=False, ove
pickle
.
dump
(
data_dict
,
open
(
save_path
,
'wb'
))
pickle
.
dump
(
data_dict
,
open
(
save_path
,
'wb'
))
continue
continue
else
:
else
:
anno_obj
=
{}
if
config_dict
[
'dataset_cfg'
][
'anno_ext'
]
==
'.txt'
:
anno_obj
[
'metadata'
]
=
load_anno_txt
(
seq_anno_path
,
n_frame
,
dataset
)
elif
config_dict
[
'dataset_cfg'
][
'anno_ext'
]
==
'.json'
:
with
open
(
os
.
path
.
join
(
seq_anno_path
),
'r'
)
as
f
:
with
open
(
os
.
path
.
join
(
seq_anno_path
),
'r'
)
as
f
:
anno
=
json
.
load
(
f
)
anno
=
json
.
load
(
f
)
anno_obj
=
{}
anno_obj
[
'metadata'
]
=
anno
[
'metadata'
]
anno_obj
[
'metadata'
]
=
anno
[
'metadata'
]
anno_obj
[
'confmaps'
]
=
[]
for
metadata_frame
in
anno
[
'metadata'
]:
n_obj
=
metadata_frame
[
'rad_h'
][
'n_objects'
]
obj_info
=
metadata_frame
[
'rad_h'
][
'obj_info'
]
if
n_obj
==
0
:
confmap_gt
=
np
.
zeros
(
(
n_class
+
1
,
radar_configs
[
'ramap_rsize'
],
radar_configs
[
'ramap_asize'
]),
dtype
=
float
)
confmap_gt
[
-
1
,
:,
:]
=
1.0
# initialize noise channal
else
:
else
:
confmap_gt
=
generate_confmap
(
n_obj
,
obj_info
,
dataset
,
config_dict
)
raise
confmap_gt
=
normalize_confmap
(
confmap_gt
)
confmap_gt
=
add_noise_channel
(
confmap_gt
,
dataset
,
config_dict
)
assert
confmap_gt
.
shape
==
(
n_class
+
1
,
radar_configs
[
'ramap_rsize'
],
radar_configs
[
'ramap_asize'
])
if
viz
:
visualize_confmap
(
confmap_gt
)
anno_obj
[
'confmaps'
].
append
(
confmap_gt
)
# end objects loop
anno_obj
[
'confmaps'
]
=
np
.
array
(
anno_obj
[
'confmaps'
]
)
anno_obj
[
'confmaps'
]
=
generate_confmaps
(
anno_obj
[
'metadata'
],
n_class
,
viz
)
data_dict
[
'anno'
]
=
anno_obj
data_dict
[
'anno'
]
=
anno_obj
# save pkl files
# save pkl files
pickle
.
dump
(
data_dict
,
open
(
save_path
,
'wb'
))
pickle
.
dump
(
data_dict
,
open
(
save_path
,
'wb'
))
# end frames loop
# end frames loop
...
@@ -151,7 +197,7 @@ if __name__ == "__main__":
...
@@ -151,7 +197,7 @@ if __name__ == "__main__":
out_data_dir
=
args
.
out_data_dir
out_data_dir
=
args
.
out_data_dir
overwrite
=
args
.
overwrite
overwrite
=
args
.
overwrite
dataset
=
CRUW
(
data_root
=
data_root
)
dataset
=
CRUW
(
data_root
=
data_root
,
sensor_config_name
=
'sensor_config_rod2021'
)
config_dict
=
load_configs_from_file
(
args
.
config
)
config_dict
=
load_configs_from_file
(
args
.
config
)
radar_configs
=
dataset
.
sensor_cfg
.
radar_cfg
radar_configs
=
dataset
.
sensor_cfg
.
radar_cfg
...
...
tools/train.py
View file @
1d3dead7
...
@@ -10,7 +10,7 @@ from torch.optim.lr_scheduler import StepLR
...
@@ -10,7 +10,7 @@ from torch.optim.lr_scheduler import StepLR
from
torch.utils.data
import
DataLoader
from
torch.utils.data
import
DataLoader
from
torch.utils.tensorboard
import
SummaryWriter
from
torch.utils.tensorboard
import
SummaryWriter
from
cruw
.cruw
import
CRUW
from
cruw
import
CRUW
from
rodnet.datasets.CRDataset
import
CRDataset
from
rodnet.datasets.CRDataset
import
CRDataset
from
rodnet.datasets.CRDatasetSM
import
CRDatasetSM
from
rodnet.datasets.CRDatasetSM
import
CRDatasetSM
...
@@ -37,16 +37,13 @@ def parse_args():
...
@@ -37,16 +37,13 @@ def parse_args():
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
args
=
parse_args
()
args
=
parse_args
()
config_dict
=
load_configs_from_file
(
args
.
config
)
config_dict
=
load_configs_from_file
(
args
.
config
)
dataset
=
CRUW
(
data_root
=
config_dict
[
'dataset_cfg'
][
'base_root'
])
# dataset = CRUW(data_root=config_dict['dataset_cfg']['base_root'])
dataset
=
CRUW
(
data_root
=
config_dict
[
'dataset_cfg'
][
'base_root'
],
sensor_config_name
=
'sensor_config_rod2021'
)
radar_configs
=
dataset
.
sensor_cfg
.
radar_cfg
radar_configs
=
dataset
.
sensor_cfg
.
radar_cfg
range_grid
=
dataset
.
range_grid
range_grid
=
dataset
.
range_grid
angle_grid
=
dataset
.
angle_grid
angle_grid
=
dataset
.
angle_grid
# config_dict['mappings'] = {}
# config_dict['mappings']['range_grid'] = range_grid.tolist()
# config_dict['mappings']['angle_grid'] = angle_grid.tolist()
model_cfg
=
config_dict
[
'model_cfg'
]
model_cfg
=
config_dict
[
'model_cfg'
]
if
model_cfg
[
'type'
]
==
'CDC'
:
if
model_cfg
[
'type'
]
==
'CDC'
:
from
rodnet.models
import
RODNetCDC
as
RODNet
from
rodnet.models
import
RODNetCDC
as
RODNet
elif
model_cfg
[
'type'
]
==
'HG'
:
elif
model_cfg
[
'type'
]
==
'HG'
:
...
@@ -132,27 +129,16 @@ if __name__ == "__main__":
...
@@ -132,27 +129,16 @@ if __name__ == "__main__":
print
(
"Building model ... (%s)"
%
model_cfg
)
print
(
"Building model ... (%s)"
%
model_cfg
)
if
model_cfg
[
'type'
]
==
'CDC'
:
if
model_cfg
[
'type'
]
==
'CDC'
:
if
'mnet_cfg'
in
model_cfg
:
rodnet
=
RODNet
(
n_class_train
,
mnet_cfg
=
model_cfg
[
'mnet_cfg'
]).
cuda
()
else
:
rodnet
=
RODNet
(
n_class_train
).
cuda
()
rodnet
=
RODNet
(
n_class_train
).
cuda
()
criterion
=
nn
.
MSELoss
()
criterion
=
nn
.
MSELoss
()
elif
model_cfg
[
'type'
]
==
'HG'
:
elif
model_cfg
[
'type'
]
==
'HG'
:
if
'mnet_cfg'
in
model_cfg
:
rodnet
=
RODNet
(
n_class_train
,
stacked_num
=
stacked_num
,
mnet_cfg
=
model_cfg
[
'mnet_cfg'
]).
cuda
()
else
:
rodnet
=
RODNet
(
n_class_train
,
stacked_num
=
stacked_num
).
cuda
()
rodnet
=
RODNet
(
n_class_train
,
stacked_num
=
stacked_num
).
cuda
()
criterion
=
nn
.
BCELoss
()
criterion
=
nn
.
BCELoss
()
elif
model_cfg
[
'type'
]
==
'HGwI'
:
elif
model_cfg
[
'type'
]
==
'HGwI'
:
if
'mnet_cfg'
in
model_cfg
:
rodnet
=
RODNet
(
n_class_train
,
stacked_num
=
stacked_num
,
mnet_cfg
=
model_cfg
[
'mnet_cfg'
]).
cuda
()
else
:
rodnet
=
RODNet
(
n_class_train
,
stacked_num
=
stacked_num
).
cuda
()
rodnet
=
RODNet
(
n_class_train
,
stacked_num
=
stacked_num
).
cuda
()
criterion
=
nn
.
BCELoss
()
criterion
=
nn
.
BCELoss
()
else
:
else
:
raise
TypeError
raise
TypeError
# criterion = FocalLoss(focusing_param=8, balance_param=0.25)
optimizer
=
optim
.
Adam
(
rodnet
.
parameters
(),
lr
=
lr
)
optimizer
=
optim
.
Adam
(
rodnet
.
parameters
(),
lr
=
lr
)
scheduler
=
StepLR
(
optimizer
,
step_size
=
config_dict
[
'train_cfg'
][
'lr_step'
],
gamma
=
0.1
)
scheduler
=
StepLR
(
optimizer
,
step_size
=
config_dict
[
'train_cfg'
][
'lr_step'
],
gamma
=
0.1
)
...
@@ -232,7 +218,6 @@ if __name__ == "__main__":
...
@@ -232,7 +218,6 @@ if __name__ == "__main__":
else
:
else
:
chirp_amp_curr
=
chirp_amp
(
data
.
numpy
()[
0
,
:,
0
,
:,
:],
radar_configs
[
'data_type'
])
chirp_amp_curr
=
chirp_amp
(
data
.
numpy
()[
0
,
:,
0
,
:,
:],
radar_configs
[
'data_type'
])
if
True
:
# draw train images
# draw train images
fig_name
=
os
.
path
.
join
(
train_viz_path
,
fig_name
=
os
.
path
.
join
(
train_viz_path
,
'%03d_%010d_%06d.png'
%
(
epoch
+
1
,
iter_count
,
iter
+
1
))
'%03d_%010d_%06d.png'
%
(
epoch
+
1
,
iter_count
,
iter
+
1
))
...
@@ -240,13 +225,6 @@ if __name__ == "__main__":
...
@@ -240,13 +225,6 @@ if __name__ == "__main__":
visualize_train_img
(
fig_name
,
img_path
,
chirp_amp_curr
,
visualize_train_img
(
fig_name
,
img_path
,
chirp_amp_curr
,
confmap_pred
[
0
,
:
n_class
,
0
,
:,
:],
confmap_pred
[
0
,
:
n_class
,
0
,
:,
:],
confmap_gt
[
0
,
:
n_class
,
0
,
:,
:])
confmap_gt
[
0
,
:
n_class
,
0
,
:,
:])
else
:
writer
.
add_image
(
'images/ramap'
,
heatmap2rgb
(
chirp_amp_curr
),
iter_count
)
writer
.
add_image
(
'images/confmap_pred'
,
prob2image
(
confmap_pred
[
0
,
:,
0
,
:,
:]),
iter_count
)
writer
.
add_image
(
'images/confmap_gt'
,
prob2image
(
confmap_gt
[
0
,
:,
0
,
:,
:]),
iter_count
)
# TODO: combine three images together
# writer.add_images('')
if
(
iter
+
1
)
%
config_dict
[
'train_cfg'
][
'save_step'
]
==
0
:
if
(
iter
+
1
)
%
config_dict
[
'train_cfg'
][
'save_step'
]
==
0
:
# validate current model
# validate current model
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment