Unverified Commit 031e129b authored by Hu Ye, committed by GitHub

fix bug in training model by amp (#4874)



* fix bug in amp

* fix bug in training by amp

* support use gradient clipping when amp is enabled
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
Co-authored-by: Prabhat Roy <prabhatroy@fb.com>
parent 8af692af
@@ -30,22 +30,23 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg
     for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
         start_time = time.time()
         image, target = image.to(device), target.to(device)
-        output = model(image)
+        with torch.cuda.amp.autocast(enabled=args.amp):
+            output = model(image)
+            loss = criterion(output, target)
         optimizer.zero_grad()
         if args.amp:
-            with torch.cuda.amp.autocast():
-                loss = criterion(output, target)
             scaler.scale(loss).backward()
             if args.clip_grad_norm is not None:
                 # we should unscale the gradients of the optimizer's assigned params before doing gradient clipping
                 scaler.unscale_(optimizer)
                 nn.utils.clip_grad_norm_(utils.get_optimizer_params(optimizer), args.clip_grad_norm)
             scaler.step(optimizer)
             scaler.update()
         else:
-            loss = criterion(output, target)
             loss.backward()
             if args.clip_grad_norm is not None:
                 nn.utils.clip_grad_norm_(utils.get_optimizer_params(optimizer), args.clip_grad_norm)
             optimizer.step()
         if model_ema and i % args.model_ema_steps == 0:
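For context, the pattern the corrected code follows, shown in isolation: run the forward pass and the loss computation under `autocast` (so mixed precision actually applies to the model, which was the bug), scale the loss for the backward pass, and unscale the gradients before clipping so the norm threshold is applied to the true gradient values. The snippet below is a minimal sketch, not the torchvision script itself: the toy model, optimizer, and `max_norm` (standing in for `args.clip_grad_norm`) are hypothetical, and it assumes a CUDA device. It follows PyTorch's documented AMP gradient-clipping recipe.

```python
import torch
import torch.nn as nn

# Hypothetical toy setup; the real script builds these from args.
model = nn.Linear(10, 2).cuda()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()
max_norm = 1.0  # stand-in for args.clip_grad_norm

for _ in range(10):
    image = torch.randn(8, 10, device="cuda")
    target = torch.randint(0, 2, (8,), device="cuda")

    # Both the forward pass and the loss run under autocast.
    with torch.cuda.amp.autocast():
        output = model(image)
        loss = criterion(output, target)

    optimizer.zero_grad()
    scaler.scale(loss).backward()

    # Gradients are currently scaled; unscale them before clipping so the
    # norm threshold is compared against the real gradient magnitudes.
    scaler.unscale_(optimizer)
    nn.utils.clip_grad_norm_(model.parameters(), max_norm)

    scaler.step(optimizer)
    scaler.update()
```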