Unverified Commit 44eeba07 authored by Linoy Tsaban, committed by GitHub

[Flux LoRAs] fix lr scheduler bug in distributed scenarios (#11242)



* add fix

* add fix

* Apply style fixes

---------
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
parent 5873377a
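
The change below replaces the pre-`accelerator.prepare` step count, which was computed from the unsharded dataloader, with a scheduler horizon that accounts for data sharding and for the fact that Accelerate ticks a prepared scheduler once per process on each optimizer step (see PR #8312 for the full reasoning). As a rough illustration of the arithmetic, here is a minimal sketch; the concrete numbers (4 processes, 1000 batches, etc.) are hypothetical and not from the commit:

```python
import math

# Hypothetical configuration, for illustration only.
num_processes = 4                  # e.g. 4 GPUs under Accelerate
len_train_dataloader = 1000        # batches in the dataloader *before* sharding
gradient_accumulation_steps = 2
num_train_epochs = 3
lr_warmup_steps = 100

# After accelerator.prepare(), each process sees ~1/num_processes of the batches.
len_after_sharding = math.ceil(len_train_dataloader / num_processes)             # 250
updates_per_epoch = math.ceil(len_after_sharding / gradient_accumulation_steps)  # 125

# New behavior: the horizon handed to get_scheduler() is scaled by
# num_processes, since the prepared scheduler advances once per process
# on every optimizer step.
num_training_steps_for_scheduler = num_train_epochs * num_processes * updates_per_epoch  # 1500
num_warmup_steps_for_scheduler = lr_warmup_steps * num_processes                         # 400

# Old behavior: the horizon was derived from the *unsharded* dataloader length,
# so with 4 processes the schedule decayed roughly 4x too slowly.
buggy_updates_per_epoch = math.ceil(len_train_dataloader / gradient_accumulation_steps)  # 500
buggy_horizon = num_train_epochs * buggy_updates_per_epoch * num_processes               # 6000
```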
@@ -1915,17 +1915,22 @@ def main(args):
         free_memory()
 
     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
 
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
         num_cycles=args.lr_num_cycles,
         power=args.lr_power,
     )
@@ -1949,7 +1954,6 @@ def main(args):
             lr_scheduler,
         )
     else:
-        print("I SHOULD BE HERE")
         transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
             transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler
         )
@@ -1961,8 +1965,14 @@ def main(args):
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        if num_training_steps_for_scheduler != args.max_train_steps:
+            logger.warning(
+                f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+                f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+                f"This inconsistency may result in the learning rate scheduler not functioning properly."
+            )
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
...
@@ -1407,17 +1407,22 @@ def main(args):
         tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
 
     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
 
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
         num_cycles=args.lr_num_cycles,
         power=args.lr_power,
     )
@@ -1444,8 +1449,14 @@ def main(args):
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        if num_training_steps_for_scheduler != args.max_train_steps:
+            logger.warning(
+                f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+                f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+                f"This inconsistency may result in the learning rate scheduler not functioning properly."
+            )
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
...
@@ -1524,17 +1524,22 @@ def main(args):
         free_memory()
 
     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
 
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
         num_cycles=args.lr_num_cycles,
         power=args.lr_power,
     )
@@ -1561,8 +1566,14 @@ def main(args):
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        if num_training_steps_for_scheduler != args.max_train_steps:
+            logger.warning(
+                f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+                f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+                f"This inconsistency may result in the learning rate scheduler not functioning properly."
+            )
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
...