"googlemock/git@developer.sourcefind.cn:yangql/googletest.git" did not exist on "b545089f511753905c0911d545ed2d25c867f563"
Unverified Commit 028ed72e authored by jihan.yang's avatar jihan.yang Committed by GitHub
Browse files

Support amp/fp16 training (#1192)



* support amp
Co-authored-by: default avataryukang <yukangchen@cse.cuhk.edu.hk>
parent e255e9a3
......@@ -48,6 +48,7 @@ def parse_config():
parser.add_argument('--logger_iter_interval', type=int, default=50, help='')
parser.add_argument('--ckpt_save_time_interval', type=int, default=300, help='in terms of seconds')
parser.add_argument('--wo_gpu_stat', action='store_true', help='')
parser.add_argument('--use_amp', action='store_true', help='use mix precision training')
args = parser.parse_args()
......@@ -56,6 +57,8 @@ def parse_config():
cfg.TAG = Path(args.cfg_file).stem
cfg.EXP_GROUP_PATH = '/'.join(args.cfg_file.split('/')[1:-1]) # remove 'cfgs' and 'xxxx.yaml'
args.use_amp = args.use_amp or cfg.OPTIMIZATION.get('USE_AMP', False)
if args.set_cfgs is not None:
cfg_from_list(args.set_cfgs, cfg)
......@@ -187,7 +190,8 @@ def main():
logger_iter_interval=args.logger_iter_interval,
ckpt_save_time_interval=args.ckpt_save_time_interval,
use_logger_to_record=not args.use_tqdm_to_record,
show_gpu_stat=not args.wo_gpu_stat
show_gpu_stat=not args.wo_gpu_stat,
use_amp=args.use_amp
)
if hasattr(train_set, 'use_shared_memory') and train_set.use_shared_memory:
......
import glob
import os
import torch
import tqdm
import time
import glob
from torch.nn.utils import clip_grad_norm_
from pcdet.utils import common_utils, commu_utils
......@@ -11,13 +11,15 @@ from pcdet.utils import common_utils, commu_utils
def train_one_epoch(model, optimizer, train_loader, model_func, lr_scheduler, accumulated_iter, optim_cfg,
rank, tbar, total_it_each_epoch, dataloader_iter, tb_log=None, leave_pbar=False,
use_logger_to_record=False, logger=None, logger_iter_interval=50, cur_epoch=None,
total_epochs=None, ckpt_save_dir=None, ckpt_save_time_interval=300, show_gpu_stat=False):
total_epochs=None, ckpt_save_dir=None, ckpt_save_time_interval=300, show_gpu_stat=False, use_amp=False):
if total_it_each_epoch == len(train_loader):
dataloader_iter = iter(train_loader)
ckpt_save_cnt = 1
start_it = accumulated_iter % total_it_each_epoch
scaler = torch.cuda.amp.GradScaler(enabled=use_amp, init_scale=optim_cfg.get('LOSS_SCALE_FP16', 2.0**16))
if rank == 0:
pbar = tqdm.tqdm(total=total_it_each_epoch, leave=leave_pbar, desc='train', dynamic_ncols=True)
data_time = common_utils.AverageMeter()
......@@ -49,11 +51,14 @@ def train_one_epoch(model, optimizer, train_loader, model_func, lr_scheduler, ac
model.train()
optimizer.zero_grad()
with torch.cuda.amp.autocast(enabled=use_amp):
loss, tb_dict, disp_dict = model_func(model, batch)
loss.backward()
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
clip_grad_norm_(model.parameters(), optim_cfg.GRAD_NORM_CLIP)
optimizer.step()
scaler.step(optimizer)
scaler.update()
accumulated_iter += 1
......@@ -127,7 +132,7 @@ def train_one_epoch(model, optimizer, train_loader, model_func, lr_scheduler, ac
def train_model(model, optimizer, train_loader, model_func, lr_scheduler, optim_cfg,
start_epoch, total_epochs, start_iter, rank, tb_log, ckpt_save_dir, train_sampler=None,
lr_warmup_scheduler=None, ckpt_save_interval=1, max_ckpt_save_num=50,
merge_all_iters_to_one_epoch=False,
merge_all_iters_to_one_epoch=False, use_amp=False,
use_logger_to_record=False, logger=None, logger_iter_interval=None, ckpt_save_time_interval=None, show_gpu_stat=False):
accumulated_iter = start_iter
with tqdm.trange(start_epoch, total_epochs, desc='epochs', dynamic_ncols=True, leave=(rank == 0)) as tbar:
......@@ -160,7 +165,8 @@ def train_model(model, optimizer, train_loader, model_func, lr_scheduler, optim_
use_logger_to_record=use_logger_to_record,
logger=logger, logger_iter_interval=logger_iter_interval,
ckpt_save_dir=ckpt_save_dir, ckpt_save_time_interval=ckpt_save_time_interval,
show_gpu_stat=show_gpu_stat
show_gpu_stat=show_gpu_stat,
use_amp=use_amp
)
# save trained model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment