Unverified Commit de45af4a authored by 0x1355's avatar 0x1355 Committed by GitHub
Browse files

Allow setting num_cycles for cosine_with_restarts lr scheduler (#3606)

Expose the num_cycles kwarg of get_scheduler() through args.lr_num_cycles.
parent b95cbdf6
......@@ -285,6 +285,12 @@ def parse_args():
parser.add_argument(
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument(
"--lr_num_cycles",
type=int,
default=1,
help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
)
parser.add_argument(
"--dataloader_num_workers",
type=int,
......@@ -739,6 +745,7 @@ def main():
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
num_cycles=args.lr_num_cycles * args.gradient_accumulation_steps,
)
# Prepare everything with our `accelerator`.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment