"src/vscode:/vscode.git/clone" did not exist on "86aa747da9c99134f6527e4562014cbdd7ebaa72"
Unverified Commit ad310af0 authored by M. Tolga Cangöz's avatar M. Tolga Cangöz Committed by GitHub
Browse files

Fix EMA in train_text_to_image_sdxl.py (#7048)

* Fix typos
parent d603ccb6
......@@ -951,6 +951,9 @@ def main(args):
unet, optimizer, train_dataloader, lr_scheduler
)
if args.use_ema:
ema_unet.to(accelerator.device)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
......@@ -1126,6 +1129,8 @@ def main(args):
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
if args.use_ema:
ema_unet.step(unet.parameters())
progress_bar.update(1)
global_step += 1
accelerator.log({"train_loss": train_loss}, step=global_step)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment