"vscode:/vscode.git/clone" did not exist on "34426a98e9ecc9f4b90d2b29ea243e095e9eb19e"
Unverified Commit 8e46d97c authored by Christopher Beckham's avatar Christopher Beckham Committed by GitHub
Browse files

Add missing restore() EMA call in train SDXL script (#7599)



* Restore unet params back to normal from EMA when validation call is finished

* empty commit

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 7e808e76
...@@ -1252,6 +1252,10 @@ def main(args): ...@@ -1252,6 +1252,10 @@ def main(args):
del pipeline del pipeline
torch.cuda.empty_cache() torch.cuda.empty_cache()
if args.use_ema:
# Switch back to the original UNet parameters.
ema_unet.restore(unet.parameters())
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
if accelerator.is_main_process: if accelerator.is_main_process:
unet = unwrap_model(unet) unet = unwrap_model(unet)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment