"...asr/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "d62875cc67f0ecae75c6edeffa1c74178308e034"
Unverified Commit 1de53bef authored by Vasilis Vryniotis, committed by GitHub

Simplify the gradient clipping code. (#4896)

parent f676f940
@@ -40,13 +40,13 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg
            if args.clip_grad_norm is not None:
                # we should unscale the gradients of optimizer's assigned params if do gradient clipping
                scaler.unscale_(optimizer)
-                nn.utils.clip_grad_norm_(utils.get_optimizer_params(optimizer), args.clip_grad_norm)
+                nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            if args.clip_grad_norm is not None:
-                nn.utils.clip_grad_norm_(utils.get_optimizer_params(optimizer), args.clip_grad_norm)
+                nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
            optimizer.step()

        if model_ema and i % args.model_ema_steps == 0:
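For context, the change keeps the same behaviour while dropping the indirection through utils.get_optimizer_params: the reference scripts build the optimizer from the model's parameters, so clipping over model.parameters() clips the same tensors. A minimal, illustrative sketch of the simplified pattern with mixed-precision training follows; the names model, criterion, images, targets and max_norm are assumptions for the example, not part of the commit.

import torch
from torch import nn

scaler = torch.cuda.amp.GradScaler()

def training_step(model, criterion, optimizer, images, targets, max_norm=None):
    # Illustrative sketch only: mirrors the clipping pattern above, not the full training loop.
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = criterion(model(images), targets)
    # Scale the loss so small fp16 gradients do not underflow during backward.
    scaler.scale(loss).backward()
    if max_norm is not None:
        # Unscale first, otherwise the norm would be computed on scaled gradients.
        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    scaler.step(optimizer)
    scaler.update()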
@@ -409,11 +409,3 @@ def reduce_across_processes(val):
    dist.barrier()
    dist.all_reduce(t)
    return t
-
-
-def get_optimizer_params(optimizer):
-    """Generator to iterate over all parameters in the optimizer param_groups."""
-    for group in optimizer.param_groups:
-        for p in group["params"]:
-            yield p
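The helper removed above simply yielded every tensor in optimizer.param_groups. When the optimizer is constructed from model.parameters(), as in these reference scripts, that generator and model.parameters() visit the same tensors, which is why the helper could be dropped. A small, hypothetical check illustrating the equivalence (the Linear model and SGD settings are made up for the example):

import torch
from torch import nn

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

def get_optimizer_params(optimizer):
    # The generator this commit removes: walk every param group of the optimizer.
    for group in optimizer.param_groups:
        for p in group["params"]:
            yield p

# Both iterables cover exactly the same parameter tensors.
assert {id(p) for p in get_optimizer_params(optimizer)} == {id(p) for p in model.parameters()}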