Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
RODNet
Commits
d0140132
Commit
d0140132
authored
Jan 29, 2022
by
Yizhou Wang
Browse files
v1.1 code for RODNet J-STSP version
parent
9266cc35
Changes
22
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
34 additions
and
8 deletions
+34
-8
tools/prepare_dataset/prepare_data.py
tools/prepare_dataset/prepare_data.py
+4
-2
tools/train.py
tools/train.py
+30
-6
No files found.
tools/prepare_dataset/prepare_data.py
View file @
d0140132
...
@@ -11,7 +11,7 @@ from cruw.annotation.init_json import init_meta_json
...
@@ -11,7 +11,7 @@ from cruw.annotation.init_json import init_meta_json
from
cruw.mapping
import
ra2idx
from
cruw.mapping
import
ra2idx
from
rodnet.core.confidence_map
import
generate_confmap
,
normalize_confmap
,
add_noise_channel
from
rodnet.core.confidence_map
import
generate_confmap
,
normalize_confmap
,
add_noise_channel
from
rodnet.utils.load_configs
import
load_configs_from_file
from
rodnet.utils.load_configs
import
load_configs_from_file
,
update_config_dict
from
rodnet.utils.visualization
import
visualize_confmap
from
rodnet.utils.visualization
import
visualize_confmap
SPLITS_LIST
=
[
'train'
,
'valid'
,
'test'
,
'demo'
]
SPLITS_LIST
=
[
'train'
,
'valid'
,
'test'
,
'demo'
]
...
@@ -20,7 +20,8 @@ SPLITS_LIST = ['train', 'valid', 'test', 'demo']
...
@@ -20,7 +20,8 @@ SPLITS_LIST = ['train', 'valid', 'test', 'demo']
def
parse_args
():
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
'Prepare RODNet data.'
)
parser
=
argparse
.
ArgumentParser
(
description
=
'Prepare RODNet data.'
)
parser
.
add_argument
(
'--config'
,
type
=
str
,
dest
=
'config'
,
help
=
'configuration file path'
)
parser
.
add_argument
(
'--config'
,
type
=
str
,
dest
=
'config'
,
help
=
'configuration file path'
)
parser
.
add_argument
(
'--data_root'
,
type
=
str
,
help
=
'directory to the prepared data'
)
parser
.
add_argument
(
'--data_root'
,
type
=
str
,
help
=
'directory to the dataset (will overwrite data_root in config file)'
)
parser
.
add_argument
(
'--sensor_config'
,
type
=
str
,
default
=
'sensor_config_rod2021'
)
parser
.
add_argument
(
'--sensor_config'
,
type
=
str
,
default
=
'sensor_config_rod2021'
)
parser
.
add_argument
(
'--split'
,
type
=
str
,
dest
=
'split'
,
default
=
''
,
parser
.
add_argument
(
'--split'
,
type
=
str
,
dest
=
'split'
,
default
=
''
,
help
=
'choose from train, valid, test, supertest'
)
help
=
'choose from train, valid, test, supertest'
)
...
@@ -220,6 +221,7 @@ if __name__ == "__main__":
...
@@ -220,6 +221,7 @@ if __name__ == "__main__":
dataset
=
CRUW
(
data_root
=
data_root
,
sensor_config_name
=
args
.
sensor_config
)
dataset
=
CRUW
(
data_root
=
data_root
,
sensor_config_name
=
args
.
sensor_config
)
config_dict
=
load_configs_from_file
(
args
.
config
)
config_dict
=
load_configs_from_file
(
args
.
config
)
config_dict
=
update_config_dict
(
config_dict
,
args
)
# update configs by args
radar_configs
=
dataset
.
sensor_cfg
.
radar_cfg
radar_configs
=
dataset
.
sensor_cfg
.
radar_cfg
if
splits
==
None
:
if
splits
==
None
:
...
...
tools/train.py
View file @
d0140132
...
@@ -18,7 +18,7 @@ from rodnet.datasets.CRDataLoader import CRDataLoader
...
@@ -18,7 +18,7 @@ from rodnet.datasets.CRDataLoader import CRDataLoader
from
rodnet.datasets.collate_functions
import
cr_collate
from
rodnet.datasets.collate_functions
import
cr_collate
from
rodnet.core.radar_processing
import
chirp_amp
from
rodnet.core.radar_processing
import
chirp_amp
from
rodnet.utils.solve_dir
import
create_dir_for_new_model
from
rodnet.utils.solve_dir
import
create_dir_for_new_model
from
rodnet.utils.load_configs
import
load_configs_from_file
from
rodnet.utils.load_configs
import
load_configs_from_file
,
update_config_dict
from
rodnet.utils.visualization
import
visualize_train_img
from
rodnet.utils.visualization
import
visualize_train_img
...
@@ -51,6 +51,12 @@ if __name__ == "__main__":
...
@@ -51,6 +51,12 @@ if __name__ == "__main__":
from
rodnet.models
import
RODNetHG
as
RODNet
from
rodnet.models
import
RODNetHG
as
RODNet
elif
model_cfg
[
'type'
]
==
'HGwI'
:
elif
model_cfg
[
'type'
]
==
'HGwI'
:
from
rodnet.models
import
RODNetHGwI
as
RODNet
from
rodnet.models
import
RODNetHGwI
as
RODNet
elif
model_cfg
[
'type'
]
==
'CDCv2'
:
from
rodnet.models
import
RODNetCDCDCN
as
RODNet
elif
model_cfg
[
'type'
]
==
'HGv2'
:
from
rodnet.models
import
RODNetHGDCN
as
RODNet
elif
model_cfg
[
'type'
]
==
'HGwIv2'
:
from
rodnet.models
import
RODNetHGwIDCN
as
RODNet
else
:
else
:
raise
NotImplementedError
raise
NotImplementedError
...
@@ -110,7 +116,7 @@ if __name__ == "__main__":
...
@@ -110,7 +116,7 @@ if __name__ == "__main__":
# dataloader_valid = DataLoader(crdata_valid, batch_size=batch_size, shuffle=True, num_workers=0)
# dataloader_valid = DataLoader(crdata_valid, batch_size=batch_size, shuffle=True, num_workers=0)
else
:
else
:
crdata_train
=
CRDatasetSM
(
data_
root
=
args
.
data_dir
,
config_dict
=
config_dict
,
split
=
'train'
,
crdata_train
=
CRDatasetSM
(
data_
dir
=
args
.
data_dir
,
dataset
=
dataset
,
config_dict
=
config_dict
,
split
=
'train'
,
noise_channel
=
args
.
use_noise_channel
)
noise_channel
=
args
.
use_noise_channel
)
seq_names
=
crdata_train
.
seq_names
seq_names
=
crdata_train
.
seq_names
index_mapping
=
crdata_train
.
index_mapping
index_mapping
=
crdata_train
.
index_mapping
...
@@ -130,13 +136,31 @@ if __name__ == "__main__":
...
@@ -130,13 +136,31 @@ if __name__ == "__main__":
print
(
"Building model ... (%s)"
%
model_cfg
)
print
(
"Building model ... (%s)"
%
model_cfg
)
if
model_cfg
[
'type'
]
==
'CDC'
:
if
model_cfg
[
'type'
]
==
'CDC'
:
rodnet
=
RODNet
(
n_class_train
).
cuda
()
rodnet
=
RODNet
(
in_channels
=
2
,
n_class
=
n_class_train
).
cuda
()
criterion
=
nn
.
MS
ELoss
()
criterion
=
nn
.
BC
ELoss
()
elif
model_cfg
[
'type'
]
==
'HG'
:
elif
model_cfg
[
'type'
]
==
'HG'
:
rodnet
=
RODNet
(
n_class_train
,
stacked_num
=
stacked_num
).
cuda
()
rodnet
=
RODNet
(
in_channels
=
2
,
n_class
=
n_class_train
,
stacked_num
=
stacked_num
).
cuda
()
criterion
=
nn
.
BCELoss
()
criterion
=
nn
.
BCELoss
()
elif
model_cfg
[
'type'
]
==
'HGwI'
:
elif
model_cfg
[
'type'
]
==
'HGwI'
:
rodnet
=
RODNet
(
n_class_train
,
stacked_num
=
stacked_num
).
cuda
()
rodnet
=
RODNet
(
in_channels
=
2
,
n_class
=
n_class_train
,
stacked_num
=
stacked_num
).
cuda
()
criterion
=
nn
.
BCELoss
()
elif
model_cfg
[
'type'
]
==
'CDCv2'
:
in_chirps
=
len
(
radar_configs
[
'chirp_ids'
])
rodnet
=
RODNet
(
in_channels
=
in_chirps
,
n_class
=
n_class_train
,
mnet_cfg
=
config_dict
[
'model_cfg'
][
'mnet_cfg'
],
dcn
=
config_dict
[
'model_cfg'
][
'dcn'
]).
cuda
()
criterion
=
nn
.
BCELoss
()
elif
model_cfg
[
'type'
]
==
'HGv2'
:
in_chirps
=
len
(
radar_configs
[
'chirp_ids'
])
rodnet
=
RODNet
(
in_channels
=
in_chirps
,
n_class
=
n_class_train
,
stacked_num
=
stacked_num
,
mnet_cfg
=
config_dict
[
'model_cfg'
][
'mnet_cfg'
],
dcn
=
config_dict
[
'model_cfg'
][
'dcn'
]).
cuda
()
criterion
=
nn
.
BCELoss
()
elif
model_cfg
[
'type'
]
==
'HGwIv2'
:
in_chirps
=
len
(
radar_configs
[
'chirp_ids'
])
rodnet
=
RODNet
(
in_channels
=
in_chirps
,
n_class
=
n_class_train
,
stacked_num
=
stacked_num
,
mnet_cfg
=
config_dict
[
'model_cfg'
][
'mnet_cfg'
],
dcn
=
config_dict
[
'model_cfg'
][
'dcn'
]).
cuda
()
criterion
=
nn
.
BCELoss
()
criterion
=
nn
.
BCELoss
()
else
:
else
:
raise
TypeError
raise
TypeError
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment