Unverified Commit 8e46d97c authored by Christopher Beckham's avatar Christopher Beckham Committed by GitHub
Browse files

Add missing restore() EMA call in train SDXL script (#7599)



* Restore unet params back to normal from EMA when validation call is finished

* empty commit

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 7e808e76
......@@ -1252,6 +1252,10 @@ def main(args):
del pipeline
torch.cuda.empty_cache()
if args.use_ema:
# Switch back to the original UNet parameters.
ema_unet.restore(unet.parameters())
accelerator.wait_for_everyone()
if accelerator.is_main_process:
unet = unwrap_model(unet)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment