Unverified commit 028ed72e authored by jihan.yang, committed by GitHub

Support amp/fp16 training (#1192)



* support amp
Co-authored-by: yukang <yukangchen@cse.cuhk.edu.hk>
parent e255e9a3
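
For context, the commit follows the standard PyTorch automatic mixed precision (AMP) recipe: run the forward pass under autocast, scale the loss before backward so fp16 gradients don't underflow, unscale before gradient clipping, and let GradScaler skip any step whose gradients overflowed. A minimal self-contained sketch of that recipe (placeholder model and data, not the project's code; requires a CUDA device):

    import torch

    model = torch.nn.Linear(10, 1).cuda()              # placeholder model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    scaler = torch.cuda.amp.GradScaler(enabled=True)   # maintains the dynamic loss scale

    for _ in range(3):
        x = torch.randn(4, 10, device='cuda')
        target = torch.randn(4, 1, device='cuda')
        optimizer.zero_grad()
        with torch.cuda.amp.autocast(enabled=True):    # forward pass in mixed precision
            loss = torch.nn.functional.mse_loss(model(x), target)
        scaler.scale(loss).backward()                  # scaled backward: avoids fp16 underflow
        scaler.unscale_(optimizer)                     # unscale so clipping sees true gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), 10.0)
        scaler.step(optimizer)                         # skips optimizer.step() on inf/nan grads
        scaler.update()                                # adjusts the loss scale for the next step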
Changes to parse_config() and main() in the training entry script:

@@ -48,6 +48,7 @@ def parse_config():
     parser.add_argument('--logger_iter_interval', type=int, default=50, help='')
     parser.add_argument('--ckpt_save_time_interval', type=int, default=300, help='in terms of seconds')
     parser.add_argument('--wo_gpu_stat', action='store_true', help='')
+    parser.add_argument('--use_amp', action='store_true', help='use mixed precision training')

     args = parser.parse_args()

@@ -56,6 +57,8 @@ def parse_config():
     cfg.TAG = Path(args.cfg_file).stem
     cfg.EXP_GROUP_PATH = '/'.join(args.cfg_file.split('/')[1:-1])  # remove 'cfgs' and 'xxxx.yaml'

+    args.use_amp = args.use_amp or cfg.OPTIMIZATION.get('USE_AMP', False)
+
     if args.set_cfgs is not None:
         cfg_from_list(args.set_cfgs, cfg)

@@ -187,7 +190,8 @@ def main():
         logger_iter_interval=args.logger_iter_interval,
         ckpt_save_time_interval=args.ckpt_save_time_interval,
         use_logger_to_record=not args.use_tqdm_to_record,
-        show_gpu_stat=not args.wo_gpu_stat
+        show_gpu_stat=not args.wo_gpu_stat,
+        use_amp=args.use_amp
     )

     if hasattr(train_set, 'use_shared_memory') and train_set.use_shared_memory:
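
With these changes AMP can be switched on either way: pass --use_amp on the command line, or set USE_AMP: True under the OPTIMIZATION section of the model config; the two are OR-ed together. For example, `python train.py --cfg_file <your_cfg>.yaml --use_amp` (the config path here is a placeholder).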
Changes to train_one_epoch() and train_model() in the training utilities:

+import glob
 import os
 import torch
 import tqdm
 import time
-import glob
 from torch.nn.utils import clip_grad_norm_
 from pcdet.utils import common_utils, commu_utils

@@ -11,13 +11,15 @@
 def train_one_epoch(model, optimizer, train_loader, model_func, lr_scheduler, accumulated_iter, optim_cfg,
                     rank, tbar, total_it_each_epoch, dataloader_iter, tb_log=None, leave_pbar=False,
                     use_logger_to_record=False, logger=None, logger_iter_interval=50, cur_epoch=None,
-                    total_epochs=None, ckpt_save_dir=None, ckpt_save_time_interval=300, show_gpu_stat=False):
+                    total_epochs=None, ckpt_save_dir=None, ckpt_save_time_interval=300, show_gpu_stat=False, use_amp=False):
     if total_it_each_epoch == len(train_loader):
         dataloader_iter = iter(train_loader)

     ckpt_save_cnt = 1
     start_it = accumulated_iter % total_it_each_epoch
+    scaler = torch.cuda.amp.GradScaler(enabled=use_amp, init_scale=optim_cfg.get('LOSS_SCALE_FP16', 2.0**16))
+
     if rank == 0:
         pbar = tqdm.tqdm(total=total_it_each_epoch, leave=leave_pbar, desc='train', dynamic_ncols=True)
         data_time = common_utils.AverageMeter()
@@ -49,11 +51,14 @@ def train_one_epoch(...)
         model.train()
         optimizer.zero_grad()

-        loss, tb_dict, disp_dict = model_func(model, batch)
+        with torch.cuda.amp.autocast(enabled=use_amp):
+            loss, tb_dict, disp_dict = model_func(model, batch)

-        loss.backward()
+        scaler.scale(loss).backward()
+        scaler.unscale_(optimizer)
         clip_grad_norm_(model.parameters(), optim_cfg.GRAD_NORM_CLIP)
-        optimizer.step()
+        scaler.step(optimizer)
+        scaler.update()

         accumulated_iter += 1
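
Note the ordering in this hunk: scaler.unscale_(optimizer) runs before clip_grad_norm_, so the clip threshold applies to the true (unscaled) gradients rather than the loss-scaled ones. scaler.step() then detects that the gradients were already unscaled and does not unscale them a second time, and it silently skips the optimizer step whenever the gradients contain infs or NaNs.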
@@ -127,7 +132,7 @@
 def train_model(model, optimizer, train_loader, model_func, lr_scheduler, optim_cfg,
                 start_epoch, total_epochs, start_iter, rank, tb_log, ckpt_save_dir, train_sampler=None,
                 lr_warmup_scheduler=None, ckpt_save_interval=1, max_ckpt_save_num=50,
-                merge_all_iters_to_one_epoch=False,
+                merge_all_iters_to_one_epoch=False, use_amp=False,
                 use_logger_to_record=False, logger=None, logger_iter_interval=None, ckpt_save_time_interval=None, show_gpu_stat=False):
     accumulated_iter = start_iter
     with tqdm.trange(start_epoch, total_epochs, desc='epochs', dynamic_ncols=True, leave=(rank == 0)) as tbar:

@@ -160,7 +165,8 @@
                 use_logger_to_record=use_logger_to_record,
                 logger=logger, logger_iter_interval=logger_iter_interval,
                 ckpt_save_dir=ckpt_save_dir, ckpt_save_time_interval=ckpt_save_time_interval,
-                show_gpu_stat=show_gpu_stat
+                show_gpu_stat=show_gpu_stat,
+                use_amp=use_amp
             )

             # save trained model
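
Because GradScaler and autocast are constructed with enabled=use_amp, the rewritten loop degrades to plain fp32 training when the flag is off: a disabled scaler passes the loss through untouched and scaler.step() reduces to optimizer.step(). The initial loss scale is configurable through the optional OPTIMIZATION.LOSS_SCALE_FP16 config key, defaulting to 2**16 (GradScaler's own default init_scale). A small sketch of the no-op behavior, runnable on CPU:

    import torch

    # With enabled=False, GradScaler is a pass-through, which is why the diff
    # can wrap the existing fp32 code path unconditionally.
    scaler = torch.cuda.amp.GradScaler(enabled=False)
    loss = torch.tensor(1.0, requires_grad=True)
    assert scaler.scale(loss) is loss   # no scaling applied when disabled
    scaler.scale(loss).backward()       # behaves exactly like loss.backward()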