Unverified Commit db7b7bd9 authored by Anton Lozhkov's avatar Anton Lozhkov Committed by GitHub
Browse files

[Train unconditional] Unwrap model before EMA (#1469)

parent 6a0a3123
@@ -320,7 +320,12 @@ def main(args):
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay)
+    ema_model = EMAModel(
+        accelerator.unwrap_model(model),
+        inv_gamma=args.ema_inv_gamma,
+        power=args.ema_power,
+        max_value=args.ema_max_decay,
+    )

     # Handle the repository creation
     if accelerator.is_main_process:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment