"tests/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "f57b27d2ad52b7d3958a83d9ec4ce4cb6d42f615"
Unverified commit fe8d1302 authored by Charbel Abi Daher, committed by GitHub

Added passing parameters to "reduce_lr_on_plateau" scheduler (#27860)

parent 56be5e80
@@ -53,19 +53,22 @@ def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
     return LambdaLR(optimizer, _get_constant_lambda, last_epoch=last_epoch)
 
 
-def get_reduce_on_plateau_schedule(optimizer: Optimizer):
+def get_reduce_on_plateau_schedule(optimizer: Optimizer, **kwargs):
     """
     Create a schedule with a constant learning rate that decreases when a metric has stopped improving.
 
     Args:
         optimizer ([`~torch.optim.Optimizer`]):
             The optimizer for which to schedule the learning rate.
+        kwargs (`dict`, *optional*):
+            Extra parameters to be passed to the scheduler. See `torch.optim.lr_scheduler.ReduceLROnPlateau`
+            for possible parameters.
 
     Return:
         `torch.optim.lr_scheduler.ReduceLROnPlateau` with the appropriate schedule.
     """
-    return ReduceLROnPlateau(optimizer)
+    return ReduceLROnPlateau(optimizer, **kwargs)
 
 
 def _get_constant_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int):
@@ -359,9 +362,15 @@ def get_scheduler(
     """
     name = SchedulerType(name)
     schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
-    if name == SchedulerType.CONSTANT or name == SchedulerType.REDUCE_ON_PLATEAU:
+    if name == SchedulerType.CONSTANT:
         return schedule_func(optimizer)
 
+    if scheduler_specific_kwargs is None:
+        scheduler_specific_kwargs = {}
+
+    if name == SchedulerType.REDUCE_ON_PLATEAU:
+        return schedule_func(optimizer, **scheduler_specific_kwargs)
+
     # All other schedulers require `num_warmup_steps`
     if num_warmup_steps is None:
         raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
@@ -376,9 +385,6 @@ def get_scheduler(
     if num_training_steps is None:
         raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
 
-    if scheduler_specific_kwargs is None:
-        scheduler_specific_kwargs = {}
-
     return schedule_func(
         optimizer,
         num_warmup_steps=num_warmup_steps,
...
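The effect of this change is that keyword arguments passed to `get_scheduler` via `scheduler_specific_kwargs` are now forwarded to `torch.optim.lr_scheduler.ReduceLROnPlateau`. A minimal sketch of how the updated API might be used follows; the toy model, optimizer, and the `patience`/`factor` values are illustrative only and not part of this commit:

    import torch
    from transformers import get_scheduler

    # Toy model and optimizer purely for illustration.
    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    # After this change, scheduler_specific_kwargs is forwarded to ReduceLROnPlateau,
    # so its constructor options (patience, factor, threshold, ...) can be set here.
    scheduler = get_scheduler(
        name="reduce_lr_on_plateau",
        optimizer=optimizer,
        scheduler_specific_kwargs={"patience": 2, "factor": 0.5},
    )

    # ReduceLROnPlateau steps on a monitored metric (e.g. validation loss),
    # unlike the warmup-based schedules, which step once per optimizer update.
    val_loss = 0.42  # placeholder metric
    scheduler.step(val_loss)

Note that the `scheduler_specific_kwargs is None` check is moved up so it runs before the `REDUCE_ON_PLATEAU` branch; previously it only applied to the warmup-based schedulers further down.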