Unverified commit 14f4af8f, authored by Suraj Patil, committed by GitHub

[dreambooth] fix applying clip_grad_norm_ (#686)

fix applying clip grad norm
parent 2558977b
@@ -566,7 +566,8 @@ def main():
                 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
                 accelerator.backward(loss)
-                accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
+                if accelerator.sync_gradients:
+                    accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
                 optimizer.step()
                 lr_scheduler.step()
                 optimizer.zero_grad()
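Why the guard matters: with gradient accumulation, gradients are summed over several micro-batches and only synchronized every N steps. Clipping on every micro-batch rescales partial sums, so the final accumulated gradient can end up with a different magnitude and even a different direction than clipping the fully accumulated gradient once. The sketch below is a plain-Python illustration of that effect (no torch or accelerate; `clip_norm` and `accumulate` are hypothetical helpers, and the `sync` flag stands in for `accelerator.sync_gradients`), not the library's implementation.

```python
def clip_norm(grads, max_norm):
    """Scale grads in place so their L2 norm is at most max_norm."""
    norm = sum(g * g for g in grads) ** 0.5
    if norm > max_norm:
        scale = max_norm / norm
        grads[:] = [g * scale for g in grads]
    return norm

def accumulate(micro_grads, accum_steps, max_norm, clip_every_step):
    """Accumulate micro-batch gradients; clip either every step (the bug)
    or only at sync boundaries (the fix)."""
    grads = [0.0] * len(micro_grads[0])
    for step, mg in enumerate(micro_grads, start=1):
        grads = [a + b for a, b in zip(grads, mg)]
        sync = step % accum_steps == 0  # analogue of accelerator.sync_gradients
        if clip_every_step or sync:
            clip_norm(grads, max_norm)
    return grads

# Two micro-batches pointing in different directions, accumulated over 2 steps.
buggy = accumulate([[3.0, 4.0], [4.0, -3.0]], 2, 2.5, clip_every_step=True)
fixed = accumulate([[3.0, 4.0], [4.0, -3.0]], 2, 2.5, clip_every_step=False)
```

Here `fixed` clips the true sum [7.0, 1.0] once, while `buggy` rescales the first micro-batch before the second is added, so the two results diverge even though both satisfy the norm bound.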