Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenPCDet
Commits
02ac3e17
Unverified
Commit
02ac3e17
authored
May 13, 2023
by
Shaoshuai Shi
Committed by
GitHub
May 13, 2023
Browse files
Support multi-modal 3D detection on NuScenes #1339
Add support for multi-modal NuScenes Detection
parents
ad9c25c0
fcfa0773
Changes
41
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
27 additions
and
2 deletions
+27
-2
tools/train_utils/train_utils.py
tools/train_utils/train_utils.py
+27
-2
No files found.
tools/train_utils/train_utils.py
View file @
02ac3e17
...
@@ -39,7 +39,7 @@ def train_one_epoch(model, optimizer, train_loader, model_func, lr_scheduler, ac
...
@@ -39,7 +39,7 @@ def train_one_epoch(model, optimizer, train_loader, model_func, lr_scheduler, ac
data_timer
=
time
.
time
()
data_timer
=
time
.
time
()
cur_data_time
=
data_timer
-
end
cur_data_time
=
data_timer
-
end
lr_scheduler
.
step
(
accumulated_iter
)
lr_scheduler
.
step
(
accumulated_iter
,
cur_epoch
)
try
:
try
:
cur_lr
=
float
(
optimizer
.
lr
)
cur_lr
=
float
(
optimizer
.
lr
)
...
@@ -151,8 +151,13 @@ def train_model(model, optimizer, train_loader, model_func, lr_scheduler, optim_
...
@@ -151,8 +151,13 @@ def train_model(model, optimizer, train_loader, model_func, lr_scheduler, optim_
start_epoch
,
total_epochs
,
start_iter
,
rank
,
tb_log
,
ckpt_save_dir
,
train_sampler
=
None
,
start_epoch
,
total_epochs
,
start_iter
,
rank
,
tb_log
,
ckpt_save_dir
,
train_sampler
=
None
,
lr_warmup_scheduler
=
None
,
ckpt_save_interval
=
1
,
max_ckpt_save_num
=
50
,
lr_warmup_scheduler
=
None
,
ckpt_save_interval
=
1
,
max_ckpt_save_num
=
50
,
merge_all_iters_to_one_epoch
=
False
,
use_amp
=
False
,
merge_all_iters_to_one_epoch
=
False
,
use_amp
=
False
,
use_logger_to_record
=
False
,
logger
=
None
,
logger_iter_interval
=
None
,
ckpt_save_time_interval
=
None
,
show_gpu_stat
=
False
):
use_logger_to_record
=
False
,
logger
=
None
,
logger_iter_interval
=
None
,
ckpt_save_time_interval
=
None
,
show_gpu_stat
=
False
,
cfg
=
None
):
accumulated_iter
=
start_iter
accumulated_iter
=
start_iter
# use for disable data augmentation hook
hook_config
=
cfg
.
get
(
'HOOK'
,
None
)
augment_disable_flag
=
False
with
tqdm
.
trange
(
start_epoch
,
total_epochs
,
desc
=
'epochs'
,
dynamic_ncols
=
True
,
leave
=
(
rank
==
0
))
as
tbar
:
with
tqdm
.
trange
(
start_epoch
,
total_epochs
,
desc
=
'epochs'
,
dynamic_ncols
=
True
,
leave
=
(
rank
==
0
))
as
tbar
:
total_it_each_epoch
=
len
(
train_loader
)
total_it_each_epoch
=
len
(
train_loader
)
if
merge_all_iters_to_one_epoch
:
if
merge_all_iters_to_one_epoch
:
...
@@ -170,6 +175,8 @@ def train_model(model, optimizer, train_loader, model_func, lr_scheduler, optim_
...
@@ -170,6 +175,8 @@ def train_model(model, optimizer, train_loader, model_func, lr_scheduler, optim_
cur_scheduler
=
lr_warmup_scheduler
cur_scheduler
=
lr_warmup_scheduler
else
:
else
:
cur_scheduler
=
lr_scheduler
cur_scheduler
=
lr_scheduler
augment_disable_flag
=
disable_augmentation_hook
(
hook_config
,
dataloader_iter
,
total_epochs
,
cur_epoch
,
cfg
,
augment_disable_flag
,
logger
)
accumulated_iter
=
train_one_epoch
(
accumulated_iter
=
train_one_epoch
(
model
,
optimizer
,
train_loader
,
model_func
,
model
,
optimizer
,
train_loader
,
model_func
,
lr_scheduler
=
cur_scheduler
,
lr_scheduler
=
cur_scheduler
,
...
@@ -245,3 +252,21 @@ def save_checkpoint(state, filename='checkpoint'):
...
@@ -245,3 +252,21 @@ def save_checkpoint(state, filename='checkpoint'):
torch
.
save
(
state
,
filename
,
_use_new_zipfile_serialization
=
False
)
torch
.
save
(
state
,
filename
,
_use_new_zipfile_serialization
=
False
)
else
:
else
:
torch
.
save
(
state
,
filename
)
torch
.
save
(
state
,
filename
)
def
disable_augmentation_hook
(
hook_config
,
dataloader
,
total_epochs
,
cur_epoch
,
cfg
,
flag
,
logger
):
"""
This hook turns off the data augmentation during training.
"""
if
hook_config
is
not
None
:
DisableAugmentationHook
=
hook_config
.
get
(
'DisableAugmentationHook'
,
None
)
if
DisableAugmentationHook
is
not
None
:
num_last_epochs
=
DisableAugmentationHook
.
NUM_LAST_EPOCHS
if
(
total_epochs
-
num_last_epochs
)
<=
cur_epoch
and
not
flag
:
DISABLE_AUG_LIST
=
DisableAugmentationHook
.
DISABLE_AUG_LIST
dataset_cfg
=
cfg
.
DATA_CONFIG
logger
.
info
(
f
'Disable augmentations:
{
DISABLE_AUG_LIST
}
'
)
dataset_cfg
.
DATA_AUGMENTOR
.
DISABLE_AUG_LIST
=
DISABLE_AUG_LIST
dataloader
.
_dataset
.
data_augmentor
.
disable_augmentation
(
dataset_cfg
.
DATA_AUGMENTOR
)
flag
=
True
return
flag
\ No newline at end of file
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment