Unverified commit 14b97549, authored by Suraj Patil, committed by GitHub

[train_unconditional] fix applying clip_grad_norm_ (#721)

fix clip_grad_norm_
parent 6b221920
@@ -143,7 +143,8 @@ def main(args):
         loss = F.mse_loss(noise_pred, noise)
         accelerator.backward(loss)
-        accelerator.clip_grad_norm_(model.parameters(), 1.0)
+        if accelerator.sync_gradients:
+            accelerator.clip_grad_norm_(model.parameters(), 1.0)
         optimizer.step()
         lr_scheduler.step()
         if args.use_ema:
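The reason for the gate: with gradient accumulation enabled, `accelerator.backward(loss)` only accumulates partial gradients on most micro-steps; `accelerator.sync_gradients` is `True` only on the step where the accumulated gradients are synchronized and about to be applied. Clipping on the intermediate micro-steps would clip incomplete gradients. A minimal pure-Python sketch of that logic (the `clip_grad_norm_` helper and the list-of-floats "gradients" here are illustrative stand-ins, not the real torch/accelerate API):

```python
import math

def clip_grad_norm_(grads, max_norm):
    """Scale the gradient list in place so its global L2 norm is at most
    max_norm; return the pre-clip norm (mirroring the semantics of
    torch.nn.utils.clip_grad_norm_)."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm > max_norm:
        scale = max_norm / (total_norm + 1e-6)
        for i in range(len(grads)):
            grads[i] *= scale
    return total_norm

# Simulated loop with gradient accumulation over `accum_steps` micro-batches.
# sync_gradients is True only when the accumulated gradients are about to be
# applied -- the only point where clipping sees the full gradient.
accum_steps = 4
grads = [0.0, 0.0]
pre_clip_norms = []
for step in range(8):
    grads[0] += 3.0   # stand-in for accelerator.backward(loss)
    grads[1] += 4.0
    sync_gradients = (step + 1) % accum_steps == 0
    if sync_gradients:
        pre_clip_norms.append(clip_grad_norm_(grads, 1.0))
        # optimizer.step() / zero_grad() would go here
        grads = [0.0, 0.0]
```

After four micro-steps the accumulated gradient is `[12.0, 16.0]` (norm 20), so clipping once at the sync point scales the full gradient; clipping inside every micro-step would instead have shrunk each partial contribution.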