"git@developer.sourcefind.cn:OpenDAS/bitsandbytes.git" did not exist on "32fa459ed7c812c79e847145004061f21b7ac0d9"
Unverified Commit 3511a962 authored by Genius Patrick, committed by GitHub

fix(training): lr scheduler doesn't work properly in distributed scenarios (#8312)

parent 42cae93b
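
For context (this note is ours, not part of the commit): when the learning rate scheduler is passed through `accelerator.prepare`, Accelerate steps it once per process on every optimizer update, which is why the training scripts multiply the warmup and total step counts by `accelerator.num_processes`. The bug fixed here: when `--max_train_steps` is unset, the old code derived the steps-per-epoch from `len(train_dataloader)` before the dataloader was sharded across processes, so in multi-GPU runs the scheduler was given a decay horizon roughly `num_processes` times longer than the steps training actually consumes. A sketch with made-up numbers:

import math

# Hypothetical numbers for illustration only (not from the commit).
batches_before_sharding = 1000   # len(train_dataloader) before accelerator.prepare
num_processes = 4
gradient_accumulation_steps = 2
num_train_epochs = 3

# Old math: epoch length taken from the unsharded dataloader.
old_updates_per_epoch = math.ceil(batches_before_sharding / gradient_accumulation_steps)  # 500
old_budget = num_train_epochs * old_updates_per_epoch * num_processes                     # 6000

# New math: shard first, then count optimizer updates per epoch.
sharded_len = math.ceil(batches_before_sharding / num_processes)                          # 250
new_updates_per_epoch = math.ceil(sharded_len / gradient_accumulation_steps)              # 125
new_budget = num_train_epochs * new_updates_per_epoch * num_processes                     # 1500

# Training actually performs 3 * 125 = 375 optimizer updates per process,
# consuming 375 * 4 = 1500 scheduler steps -- the new budget. Under the old
# budget of 6000, a linear or cosine schedule would stall at 25% of its curve.
print(old_budget, new_budget)
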
@@ -697,17 +697,22 @@ def main():
     )
     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes

     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
     )

     # Prepare everything with our `accelerator`.
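
A self-contained restatement of the hunk above (the helper name and signature are ours, purely for checking the arithmetic by hand):

import math

def scheduler_step_budget(max_train_steps, num_train_epochs,
                          dataloader_len, gradient_accumulation_steps, num_processes):
    # dataloader_len is len(train_dataloader) BEFORE accelerator.prepare;
    # prepare() shards it across processes, which the old code ignored.
    if max_train_steps is None:
        sharded_len = math.ceil(dataloader_len / num_processes)
        updates_per_epoch = math.ceil(sharded_len / gradient_accumulation_steps)
        return num_train_epochs * updates_per_epoch * num_processes
    return max_train_steps * num_processes

# Matches the worked numbers above: 3 epochs, 1000 batches, accumulation 2, 4 GPUs.
assert scheduler_step_budget(None, 3, 1000, 2, 4) == 1500
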
@@ -717,8 +722,14 @@ def main():
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+            logger.warning(
+                f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+                f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+                f"This inconsistency may result in the learning rate scheduler not functioning properly."
+            )
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
...
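
The scheduler has to be created before `accelerator.prepare` (it is one of the objects being prepared), so the step budget above is necessarily an estimate. After `prepare`, `len(train_dataloader)` reflects the real shard and `args.max_train_steps` is recomputed from it; the new warning fires when the two disagree, since a mismatched budget means the warmup and decay curve no longer line up with the steps actually taken. A minimal reproduction of that check (the post-prepare length of 252 is hypothetical):

import math

num_processes = 4
gradient_accumulation_steps = 2
num_train_epochs = 3

# Estimate made when the scheduler was created (before prepare).
len_train_dataloader_after_sharding = math.ceil(1000 / num_processes)   # 250
num_training_steps_for_scheduler = (
    num_train_epochs
    * math.ceil(len_train_dataloader_after_sharding / gradient_accumulation_steps)
    * num_processes
)                                                                        # 1500

# Hypothetical: after prepare(), each shard ends up with 252 batches.
len_train_dataloader = 252
num_update_steps_per_epoch = math.ceil(len_train_dataloader / gradient_accumulation_steps)
max_train_steps = num_train_epochs * num_update_steps_per_epoch          # 378

if num_training_steps_for_scheduler != max_train_steps * num_processes:  # 1500 != 1512
    print("scheduler budget", num_training_steps_for_scheduler,
          "does not match actual", max_train_steps * num_processes)
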