"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "66e50d4e248a32ef8f8698cf3e6f0e1040f74cfc"
Unverified commit d87cc159, authored by Emil Bogomolov, committed by GitHub

expose polynomial:power and cosine_with_restarts:num_cycles params (#1737)



* expose polynomial:power and cosine_with_restarts:num_cycles via the get_scheduler function and add them to train_dreambooth.py

* fix formatting

* fix style

* Update src/diffusers/optimization.py
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
parent e29dc972
train_dreambooth.py
@@ -204,6 +204,13 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
     )
+    parser.add_argument(
+        "--lr_num_cycles",
+        type=int,
+        default=1,
+        help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+    )
+    parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
     parser.add_argument(
         "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
     )
@@ -588,6 +595,8 @@ def main(args):
         optimizer=optimizer,
         num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
         num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        num_cycles=args.lr_num_cycles,
+        power=args.lr_power,
     )

     if args.train_text_encoder:
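As a usage illustration (not part of the diff), here is a minimal sketch of how the two new flags are meant to reach the scheduler once parse_args has run; the Args namespace, optimizer, and concrete values below are hypothetical stand-ins for what train_dreambooth.py actually builds:

import torch
from diffusers.optimization import get_scheduler

class Args:  # hypothetical stand-in for the argparse Namespace
    lr_scheduler = "cosine_with_restarts"
    lr_warmup_steps = 500
    max_train_steps = 10_000
    gradient_accumulation_steps = 2
    lr_num_cycles = 3  # new flag: hard restarts for cosine_with_restarts
    lr_power = 1.0     # new flag: power factor for polynomial

args = Args()
# Dummy single-parameter optimizer, only for illustration.
optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=5e-6)

lr_scheduler = get_scheduler(
    args.lr_scheduler,
    optimizer=optimizer,
    num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
    num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    num_cycles=args.lr_num_cycles,
    power=args.lr_power,
)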
src/diffusers/optimization.py
@@ -121,9 +121,9 @@ def get_cosine_schedule_with_warmup(
             The number of steps for the warmup phase.
         num_training_steps (`int`):
             The total number of training steps.
-        num_cycles (`float`, *optional*, defaults to 0.5):
-            The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
-            following a half-cosine).
+        num_periods (`float`, *optional*, defaults to 0.5):
+            The number of periods of the cosine function in a schedule (the default is to just decrease from the max
+            value to 0 following a half-cosine).
         last_epoch (`int`, *optional*, defaults to -1):
             The index of the last epoch when resuming training.
@@ -240,6 +240,8 @@ def get_scheduler(
     optimizer: Optimizer,
     num_warmup_steps: Optional[int] = None,
     num_training_steps: Optional[int] = None,
+    num_cycles: int = 1,
+    power: float = 1.0,
 ):
     """
     Unified API to get any scheduler from its name.
@@ -255,6 +257,12 @@ def get_scheduler(
         num_training_steps (`int``, *optional*):
             The number of training steps to do. This is not required by all schedulers (hence the argument being
             optional), the function will raise an error if it's unset and the scheduler type requires it.
+        num_cycles (`int`, *optional*):
+            The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
+        power (`float`, *optional*, defaults to 1.0):
+            Power factor. See `POLYNOMIAL` scheduler
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
     """
     name = SchedulerType(name)
     schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
@@ -272,4 +280,14 @@ def get_scheduler(
     if num_training_steps is None:
         raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")

+    if name == SchedulerType.COSINE_WITH_RESTARTS:
+        return schedule_func(
+            optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
+        )
+
+    if name == SchedulerType.POLYNOMIAL:
+        return schedule_func(
+            optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
+        )
+
     return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
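For context, a hedged sketch of what the new dispatch enables at call time; the optimizer, step counts, and power value below are illustrative, and the loop only inspects the learning rate produced by the polynomial branch (with the default power=1.0 the decay stays linear, matching the behavior before this commit):

import torch
from diffusers.optimization import get_scheduler

# Dummy single-parameter optimizer, only for illustration.
optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
scheduler = get_scheduler(
    "polynomial",
    optimizer=optimizer,
    num_warmup_steps=10,
    num_training_steps=100,
    power=2.0,  # quadratic decay instead of the default linear decay
)

for step in range(100):
    optimizer.step()
    scheduler.step()
    if step % 25 == 0:
        print(step, optimizer.param_groups[0]["lr"])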