"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "2afb2e0aacdd59ffa01d30d0da645ea1b5310a11"
Unverified Commit 8e46d97c authored by Christopher Beckham's avatar Christopher Beckham Committed by GitHub
Browse files

Add missing restore() EMA call in train SDXL script (#7599)



* Restore unet params back to normal from EMA when validation call is finished

* empty commit

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 7e808e76
...@@ -1252,6 +1252,10 @@ def main(args): ...@@ -1252,6 +1252,10 @@ def main(args):
del pipeline del pipeline
torch.cuda.empty_cache() torch.cuda.empty_cache()
if args.use_ema:
# Switch back to the original UNet parameters.
ema_unet.restore(unet.parameters())
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
if accelerator.is_main_process: if accelerator.is_main_process:
unet = unwrap_model(unet) unet = unwrap_model(unet)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment