Unverified commit ece55227 authored by Ruoxi, committed by GitHub

Multiply lr scheduler steps by `num_processes`. (#3983)

* Multiply lr scheduler steps by `num_processes`.

* Stop multiplying steps by gradient accumulation.
parent 92a57a8e
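
For context on why the multiplier changes: when the learning-rate scheduler is passed through `accelerator.prepare`, Accelerate wraps it in a scheduler that advances the underlying scheduler once per process on every `step()` call (when batches are not split across processes), and under gradient accumulation it only advances on real optimizer updates. The warmup and total step counts therefore need to be scaled by `accelerator.num_processes`, not by `args.gradient_accumulation_steps`. Below is a minimal sketch of the corrected pattern; the model, scheduler name, and step counts are illustrative placeholders, not taken from the training scripts.

```python
# Minimal sketch of the corrected scaling. Assumed/illustrative values:
# the toy model, "constant_with_warmup", and the step counts.
import torch
from accelerate import Accelerator
from diffusers.optimization import get_scheduler

accelerator = Accelerator(gradient_accumulation_steps=4)
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

max_train_steps = 1000  # desired number of real optimizer updates
lr_warmup_steps = 100

lr_scheduler = get_scheduler(
    "constant_with_warmup",
    optimizer=optimizer,
    # Accelerate's prepared scheduler advances the wrapped scheduler once per
    # process on each `lr_scheduler.step()` call, so scale by num_processes.
    # Scaling by gradient_accumulation_steps instead would stretch the
    # schedule, since the prepared scheduler only steps on actual updates.
    num_warmup_steps=lr_warmup_steps * accelerator.num_processes,
    num_training_steps=max_train_steps * accelerator.num_processes,
)

model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
```

With a single process this scaling is a no-op (`num_processes == 1`), so single-GPU runs behave as before.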
@@ -897,8 +897,8 @@ def main(args):
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
-        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+        num_training_steps=args.max_train_steps * accelerator.num_processes,
         num_cycles=args.lr_num_cycles,
         power=args.lr_power,
     )
...
@@ -1007,8 +1007,8 @@ def main(args):
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
-        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+        num_training_steps=args.max_train_steps * accelerator.num_processes,
     )
 
     # Prepare everything with our `accelerator`.
...
@@ -1075,8 +1075,8 @@ def main(args):
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
-        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+        num_training_steps=args.max_train_steps * accelerator.num_processes,
         num_cycles=args.lr_num_cycles,
         power=args.lr_power,
     )
...
@@ -1039,8 +1039,8 @@ def main(args):
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
-        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+        num_training_steps=args.max_train_steps * accelerator.num_processes,
         num_cycles=args.lr_num_cycles,
         power=args.lr_power,
     )
...
@@ -690,8 +690,8 @@ def main():
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
-        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+        num_training_steps=args.max_train_steps * accelerator.num_processes,
     )
 
     # Prepare everything with our `accelerator`.
...
@@ -600,8 +600,8 @@ def main():
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
-        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+        num_training_steps=args.max_train_steps * accelerator.num_processes,
     )
 
     if args.train_text_encoder:
...
@@ -644,8 +644,8 @@ def main():
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
-        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+        num_training_steps=args.max_train_steps * accelerator.num_processes,
     )
 
     # Prepare everything with our `accelerator`.
...
@@ -481,8 +481,8 @@ def main():
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
-        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+        num_training_steps=args.max_train_steps * accelerator.num_processes,
     )
 
     text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
...
@@ -588,8 +588,8 @@ def main():
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
-        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+        num_training_steps=args.max_train_steps * accelerator.num_processes,
     )
 
     if not train_unet:
...
@@ -701,8 +701,8 @@ def main():
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
-        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+        num_training_steps=args.max_train_steps * accelerator.num_processes,
     )
 
     # Prepare everything with our `accelerator`.
...
@@ -690,8 +690,8 @@ def main():
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
-        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+        num_training_steps=args.max_train_steps * accelerator.num_processes,
     )
 
     # Prepare everything with our `accelerator`.
...
@@ -970,8 +970,8 @@ def main(args):
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
-        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+        num_training_steps=args.max_train_steps * accelerator.num_processes,
         num_cycles=args.lr_num_cycles,
         power=args.lr_power,
     )
...
@@ -732,8 +732,8 @@ def main():
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
-        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+        num_training_steps=args.max_train_steps * accelerator.num_processes,
     )
 
     # Prepare everything with our `accelerator`.
...
@@ -741,8 +741,8 @@ def main():
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
-        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+        num_training_steps=args.max_train_steps * accelerator.num_processes,
     )
 
     # Prepare everything with our `accelerator`.
...
@@ -819,8 +819,8 @@ def main():
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
-        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+        num_training_steps=args.max_train_steps * accelerator.num_processes,
     )
 
     # Prepare everything with our `accelerator`.
...
@@ -662,8 +662,8 @@ def main():
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
-        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+        num_training_steps=args.max_train_steps * accelerator.num_processes,
     )
 
     # Prepare everything with our `accelerator`.
...
@@ -737,9 +737,9 @@ def main():
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
-        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-        num_cycles=args.lr_num_cycles * args.gradient_accumulation_steps,
+        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_cycles=args.lr_num_cycles,
     )
 
     # Prepare everything with our `accelerator`.
...