Unverified Commit 02ac3e17 authored by Shaoshuai Shi's avatar Shaoshuai Shi Committed by GitHub
Browse files

Support multi-modal 3D detection on NuScenes #1339

Add support for multi-modal NuScenes Detection
parents ad9c25c0 fcfa0773
...@@ -39,7 +39,7 @@ def train_one_epoch(model, optimizer, train_loader, model_func, lr_scheduler, ac ...@@ -39,7 +39,7 @@ def train_one_epoch(model, optimizer, train_loader, model_func, lr_scheduler, ac
data_timer = time.time() data_timer = time.time()
cur_data_time = data_timer - end cur_data_time = data_timer - end
lr_scheduler.step(accumulated_iter) lr_scheduler.step(accumulated_iter, cur_epoch)
try: try:
cur_lr = float(optimizer.lr) cur_lr = float(optimizer.lr)
...@@ -151,8 +151,13 @@ def train_model(model, optimizer, train_loader, model_func, lr_scheduler, optim_ ...@@ -151,8 +151,13 @@ def train_model(model, optimizer, train_loader, model_func, lr_scheduler, optim_
start_epoch, total_epochs, start_iter, rank, tb_log, ckpt_save_dir, train_sampler=None, start_epoch, total_epochs, start_iter, rank, tb_log, ckpt_save_dir, train_sampler=None,
lr_warmup_scheduler=None, ckpt_save_interval=1, max_ckpt_save_num=50, lr_warmup_scheduler=None, ckpt_save_interval=1, max_ckpt_save_num=50,
merge_all_iters_to_one_epoch=False, use_amp=False, merge_all_iters_to_one_epoch=False, use_amp=False,
use_logger_to_record=False, logger=None, logger_iter_interval=None, ckpt_save_time_interval=None, show_gpu_stat=False): use_logger_to_record=False, logger=None, logger_iter_interval=None, ckpt_save_time_interval=None, show_gpu_stat=False, cfg=None):
accumulated_iter = start_iter accumulated_iter = start_iter
# use for disable data augmentation hook
hook_config = cfg.get('HOOK', None)
augment_disable_flag = False
with tqdm.trange(start_epoch, total_epochs, desc='epochs', dynamic_ncols=True, leave=(rank == 0)) as tbar: with tqdm.trange(start_epoch, total_epochs, desc='epochs', dynamic_ncols=True, leave=(rank == 0)) as tbar:
total_it_each_epoch = len(train_loader) total_it_each_epoch = len(train_loader)
if merge_all_iters_to_one_epoch: if merge_all_iters_to_one_epoch:
...@@ -170,6 +175,8 @@ def train_model(model, optimizer, train_loader, model_func, lr_scheduler, optim_ ...@@ -170,6 +175,8 @@ def train_model(model, optimizer, train_loader, model_func, lr_scheduler, optim_
cur_scheduler = lr_warmup_scheduler cur_scheduler = lr_warmup_scheduler
else: else:
cur_scheduler = lr_scheduler cur_scheduler = lr_scheduler
augment_disable_flag = disable_augmentation_hook(hook_config, dataloader_iter, total_epochs, cur_epoch, cfg, augment_disable_flag, logger)
accumulated_iter = train_one_epoch( accumulated_iter = train_one_epoch(
model, optimizer, train_loader, model_func, model, optimizer, train_loader, model_func,
lr_scheduler=cur_scheduler, lr_scheduler=cur_scheduler,
...@@ -245,3 +252,21 @@ def save_checkpoint(state, filename='checkpoint'): ...@@ -245,3 +252,21 @@ def save_checkpoint(state, filename='checkpoint'):
torch.save(state, filename, _use_new_zipfile_serialization=False) torch.save(state, filename, _use_new_zipfile_serialization=False)
else: else:
torch.save(state, filename) torch.save(state, filename)
def disable_augmentation_hook(hook_config, dataloader, total_epochs, cur_epoch, cfg, flag, logger):
"""
This hook turns off the data augmentation during training.
"""
if hook_config is not None:
DisableAugmentationHook = hook_config.get('DisableAugmentationHook', None)
if DisableAugmentationHook is not None:
num_last_epochs = DisableAugmentationHook.NUM_LAST_EPOCHS
if (total_epochs - num_last_epochs) <= cur_epoch and not flag:
DISABLE_AUG_LIST = DisableAugmentationHook.DISABLE_AUG_LIST
dataset_cfg=cfg.DATA_CONFIG
logger.info(f'Disable augmentations: {DISABLE_AUG_LIST}')
dataset_cfg.DATA_AUGMENTOR.DISABLE_AUG_LIST = DISABLE_AUG_LIST
dataloader._dataset.data_augmentor.disable_augmentation(dataset_cfg.DATA_AUGMENTOR)
flag = True
return flag
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment