Unverified Commit ca980fd0 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Examples] Make sure EMA works with any device (#2382)

* Fix EMA

* up

* update
parent a60f5555
...@@ -438,6 +438,7 @@ def main(): ...@@ -438,6 +438,7 @@ def main():
if args.use_ema: if args.use_ema:
load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel) load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
ema_unet.load_state_dict(load_model.state_dict()) ema_unet.load_state_dict(load_model.state_dict())
ema_unet.to(accelerator.device)
del load_model del load_model
for i in range(len(models)): for i in range(len(models)):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment