"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "a9cb08af398c9fe06d2d62bd12942458d5dba151"
Unverified commit 0dec414d authored by Kenneth Gerald Hamilton, committed by GitHub

[train_dreambooth_lora_sdxl.py] Fix the LR Schedulers when num_train_epochs is passed in a distributed training env (#11240)
Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
parent 44eeba07
@@ -1523,17 +1523,22 @@ def main(args):
             tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
 
     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
 
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
         num_cycles=args.lr_num_cycles,
         power=args.lr_power,
     )
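To make the change above concrete, here is a small, self-contained sketch (not part of the commit; the numbers are made up and the variable names only mirror the script's arguments) of how the scheduler's step budget is computed before and after this fix when only --num_train_epochs is passed on a multi-GPU run:

# Illustrative sketch, assuming 4 processes and made-up dataloader sizes.
import math

num_batches = 1000   # len(train_dataloader) before accelerator.prepare (unsharded)
num_processes = 4    # hypothetical GPU count
grad_accum = 2       # args.gradient_accumulation_steps
epochs = 10          # args.num_train_epochs

# Old logic: max_train_steps derived from the unsharded dataloader length.
old_max_train_steps = epochs * math.ceil(num_batches / grad_accum)   # 5000
old_scheduler_budget = old_max_train_steps * num_processes           # 20000

# New logic: estimate the per-process dataloader length after sharding first.
len_after_sharding = math.ceil(num_batches / num_processes)          # 250
updates_per_epoch = math.ceil(len_after_sharding / grad_accum)       # 125
new_scheduler_budget = epochs * num_processes * updates_per_epoch    # 5000

# The prepared scheduler advances roughly num_processes times per optimizer step,
# and each process performs about updates_per_epoch optimizer steps per epoch,
# so the schedule is actually consumed about this many times:
actual_scheduler_steps = epochs * updates_per_epoch * num_processes  # 5000

print(old_scheduler_budget, new_scheduler_budget, actual_scheduler_steps)
# 20000 5000 5000 -> with the old budget the LR stops a quarter of the way
# through its decay; the new budget matches what the run actually consumes.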
@@ -1550,7 +1555,14 @@ def main(args):
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
+        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        if num_training_steps_for_scheduler != args.max_train_steps:
+            logger.warning(
+                f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+                f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+                f"This inconsistency may result in the learning rate scheduler not functioning properly."
+            )
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
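The second hunk re-derives the step and epoch counts after accelerator.prepare, because the dataloader length can change once it is sharded across processes. A minimal sketch of that recalculation follows; the numbers are invented, the names only mirror the script, and the simple length comparison at the end stands in for the script's step-count check:

# Illustrative sketch, assuming the sampler drops an incomplete final batch per process.
import math

grad_accum = 2
num_train_epochs = 10

# Estimate used when the scheduler was created (before prepare).
len_train_dataloader_after_sharding = 250   # e.g. ceil(1000 / num_processes)
expected_updates = math.ceil(len_train_dataloader_after_sharding / grad_accum)   # 125

# Length actually observed after accelerator.prepare.
len_train_dataloader = 248
num_update_steps_per_epoch = math.ceil(len_train_dataloader / grad_accum)        # 124

max_train_steps = num_train_epochs * num_update_steps_per_epoch                  # 1240
num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)       # 10

if len_train_dataloader != len_train_dataloader_after_sharding:
    print(
        f"dataloader length after prepare ({len_train_dataloader}) does not match the "
        f"expected length ({len_train_dataloader_after_sharding}); the LR schedule may "
        f"not line up with the steps actually taken."
    )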