Unverified commit 25d927aa, authored by Felix Blanke, committed by GitHub
Browse files

Add `last_epoch` argument to `optimization.get_scheduler` (#2850)

Add last_epoch arg to optimization.get_scheduler.

Allows the specification of the index of the last epoch when
resuming training.
parent 663c6545
...@@ -242,6 +242,7 @@ def get_scheduler( ...@@ -242,6 +242,7 @@ def get_scheduler(
num_training_steps: Optional[int] = None, num_training_steps: Optional[int] = None,
num_cycles: int = 1, num_cycles: int = 1,
power: float = 1.0, power: float = 1.0,
last_epoch: int = -1,
): ):
""" """
Unified API to get any scheduler from its name. Unified API to get any scheduler from its name.
...@@ -267,14 +268,14 @@ def get_scheduler( ...@@ -267,14 +268,14 @@ def get_scheduler(
name = SchedulerType(name) name = SchedulerType(name)
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
if name == SchedulerType.CONSTANT: if name == SchedulerType.CONSTANT:
return schedule_func(optimizer) return schedule_func(optimizer, last_epoch=last_epoch)
# All other schedulers require `num_warmup_steps` # All other schedulers require `num_warmup_steps`
if num_warmup_steps is None: if num_warmup_steps is None:
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
if name == SchedulerType.CONSTANT_WITH_WARMUP: if name == SchedulerType.CONSTANT_WITH_WARMUP:
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps) return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, last_epoch=last_epoch)
# All other schedulers require `num_training_steps` # All other schedulers require `num_training_steps`
if num_training_steps is None: if num_training_steps is None:
...@@ -282,12 +283,22 @@ def get_scheduler( ...@@ -282,12 +283,22 @@ def get_scheduler(
if name == SchedulerType.COSINE_WITH_RESTARTS: if name == SchedulerType.COSINE_WITH_RESTARTS:
return schedule_func( return schedule_func(
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
num_cycles=num_cycles,
last_epoch=last_epoch,
) )
if name == SchedulerType.POLYNOMIAL: if name == SchedulerType.POLYNOMIAL:
return schedule_func( return schedule_func(
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
power=power,
last_epoch=last_epoch,
) )
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps) return schedule_func(
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, last_epoch=last_epoch
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment