Unverified Commit f861cde1 authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

[train_unconditional] fix LR scheduler init (#2010)

fix lr scheduler
parent b2ea8a84
```diff
@@ -338,8 +338,8 @@ def main(args):
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps,
-        num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps,
+        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+        num_training_steps=(len(train_dataloader) * args.num_epochs),
     )
     model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
```
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment