Commit f047d659 authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower

Adds step cosine learning rate.

PiperOrigin-RevId: 388322760
parent 371b1da4
@@ -211,6 +211,44 @@ class PowerDecayWithOffsetLrConfig(base_config.Config):
  pre_offset_learning_rate: float = 1.0e6
@dataclasses.dataclass
class StepCosineLrConfig(base_config.Config):
  """Configuration for stepwise cosine learning rate decay.

  This class is a container for the piecewise cosine learning rate scheduling
  configs. It will configure an instance of the StepCosineDecayWithOffset
  Keras learning rate schedule.

  ```python
  boundaries = [0, 100000]
  values = [1.0, 0.5]
  lr_decayed_fn = (
      lr_schedule.StepCosineDecayWithOffset(
          boundaries,
          values))
  ```
  From step 0 to step 100000, the learning rate cosine-decays from 1.0 to
  0.5. From step 100000 onwards, it cosine-decays from 0.5 towards 0.0.

  Attributes:
    name: The name of the learning rate schedule. Defaults to
      StepCosineDecayWithOffset.
    boundaries: A list of ints with strictly increasing entries, giving the
      step at which each cosine decay segment starts. The first entry should
      be 0; use `offset` to shift the whole schedule. Defaults to None.
    values: A list of floats giving the starting learning rate of the segment
      that begins at the corresponding entry of `boundaries`. It must have the
      same number of elements as `boundaries`.
      The learning rate is computed as follows:
        [boundaries[0], boundaries[1]] -> cosine from values[0] to values[1]
        [boundaries[1], boundaries[2]] -> values[1] to values[2]
        ...
        [boundaries[n-1], boundaries[n]] -> values[n-1] to values[n]
        [boundaries[n], end] -> values[n] to 0.
    offset: An int. The offset applied to steps. Defaults to 0.
  """
  name: str = 'StepCosineDecayWithOffset'
boundaries: Optional[List[int]] = None
values: Optional[List[float]] = None
offset: int = 0
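To make the interval semantics above concrete, here is a hedged sketch in plain Python (not part of the diff) that reproduces the single-segment case by hand; the numbers mirror the docstring example:

```python
import math

# Hypothetical standalone illustration of the documented intervals.
boundaries = [0, 100000]
values = [1.0, 0.5]

def step_cosine(step):
  # Inside [boundaries[0], boundaries[1]]: cosine from values[0] to values[1].
  frac = (math.cos(math.pi * step / (boundaries[1] - boundaries[0])) + 1) / 2
  return (values[0] - values[1]) * frac + values[1]

print(step_cosine(0))       # 1.0
print(step_cosine(50000))   # 0.75, halfway through the half-cosine
print(step_cosine(100000))  # 0.5
```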
@dataclasses.dataclass
class LinearWarmupConfig(base_config.Config):
  """Configuration for linear warmup schedule config.
...
@@ -70,6 +70,7 @@ class LrConfig(oneof.OneOfConfig):
    power_linear: learning rate config of step^power followed by
      step^power*linear.
    power_with_offset: power decay with a step offset.
    step_cosine_with_offset: step cosine decay with a step offset.
  """
  type: Optional[str] = None
  constant: lr_cfg.ConstantLrConfig = lr_cfg.ConstantLrConfig()
@@ -82,6 +83,8 @@ class LrConfig(oneof.OneOfConfig):
      lr_cfg.PowerAndLinearDecayLrConfig())
  power_with_offset: lr_cfg.PowerDecayWithOffsetLrConfig = (
      lr_cfg.PowerDecayWithOffsetLrConfig())
step_cosine_with_offset: lr_cfg.StepCosineLrConfig = (
lr_cfg.StepCosineLrConfig())
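A hedged sketch of selecting the new one-of entry programmatically; the module paths and the `OneOfConfig.get()` accessor are assumptions based on this file's surroundings:

```python
# Hypothetical: choose the new schedule through the one-of `type` field.
from official.modeling.optimization.configs import learning_rate_config as lr_cfg
from official.modeling.optimization.configs import optimization_config

lr = optimization_config.LrConfig(
    type='step_cosine_with_offset',
    step_cosine_with_offset=lr_cfg.StepCosineLrConfig(
        boundaries=[0, 500000], values=[1.0e-4, 0.5e-4], offset=10000))
print(lr.get())  # the active StepCosineLrConfig instance
```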
@dataclasses.dataclass
...
@@ -14,6 +14,7 @@
"""Learning rate schedule classes."""
import math
from typing import Mapping, Any, Union, Optional
import tensorflow as tf
@@ -383,3 +384,113 @@ class PowerDecayWithOffset(tf.keras.optimizers.schedules.LearningRateSchedule):
        "pre_offset_learning_rate": self._pre_offset_lr,
        "name": self._name,
    }

class StepCosineDecayWithOffset(
    tf.keras.optimizers.schedules.LearningRateSchedule):
  """Stepwise cosine learning rate decay with offset.

  The learning rate is equivalent to one or more cosine decays, each starting
  and ending at consecutive boundaries.

  Example:

  ```python
  boundaries = [0, 100000]
  values = [1.0, 0.5]
  lr_decayed_fn = (
      lr_schedule.StepCosineDecayWithOffset(
          boundaries,
          values))
  ```
  From step 0 to step 100000, the learning rate cosine-decays from 1.0 to
  0.5. From step 100000 onwards, it cosine-decays from 0.5 towards 0.0.
  """
  def __init__(self,
               boundaries,
               values,
               offset: int = 0,
               name: str = "StepCosineDecayWithOffset"):
    """Initializes the learning rate schedule.

    Args:
      boundaries: A list of `Tensor`s or `int`s with strictly increasing
        entries, all elements having the same type as the optimizer step.
        Each entry gives the step at which a cosine decay segment starts.
      values: A list of `Tensor`s or `float`s that specifies the starting
        learning rate of each segment defined by `boundaries`. It must have
        the same number of elements as `boundaries`, and all elements should
        have the same type.
      offset: An int. The offset subtracted from the step before computing the
        schedule. Defaults to 0.
      name: Optional name of the learning rate schedule.
    """
super().__init__()
self.values = values
self.boundaries = boundaries
self.offset = offset
self.name = name
    if len(self.values) < 1:
      raise ValueError(f"Expect non-empty values, got {self.values}")
    if len(self.boundaries) != len(self.values):
      raise ValueError(
          "Boundaries length must be equal to learning rate levels length: "
          f"{len(self.boundaries)} != {len(self.values)}")
self.total_steps = (
[boundaries[i + 1] - boundaries[i] for i in range(len(boundaries) - 1)
] + [0])
def __call__(self, global_step):
    with tf.name_scope(self.name or "StepCosineDecayWithOffset"):
global_step = tf.cast(global_step - self.offset, tf.float32)
lr_levels = self.values
lr_steps = self.boundaries
level_total_steps = self.total_steps
num_levels = len(lr_levels)
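      # First segment: cosine decay from values[0] to values[1]
      # (or down to 0.0 when only a single level is given).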
init_lr = lr_levels[0]
next_init_lr = lr_levels[1] if num_levels > 1 else 0.
init_total_steps = level_total_steps[0]
      cosine_learning_rate = ((init_lr - next_init_lr) * (tf.cos(
          tf.constant(math.pi) * global_step / init_total_steps) + 1.0) /
                              2.0 + next_init_lr)
      learning_rate = cosine_learning_rate
      tf.compat.v1.logging.info(
          "DEBUG first segment lr %r (from %r to %r over %r steps)",
          learning_rate, init_lr, next_init_lr, init_total_steps)
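      # Later segments: once global_step passes boundaries[i], switch to the
      # cosine decay that starts there and runs for total_steps[i] steps.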
for i in range(1, num_levels):
next_init_lr = lr_levels[i]
next_start_step = lr_steps[i]
next_total_steps = level_total_steps[i]
next_next_init_lr = lr_levels[i + 1] if num_levels > i + 1 else 0.
tf.compat.v1.logging.info(
"DEBUG step %r nilr %r nss %r nts %r nnilr %r", global_step,
next_init_lr, next_start_step, next_total_steps, next_next_init_lr)
next_cosine_learning_rate = ((next_init_lr - next_next_init_lr) *
(tf.cos(
tf.constant(math.pi) *
(global_step - next_start_step) /
(next_total_steps)) + 1.0) / 2.0 +
next_next_init_lr)
learning_rate = tf.where(global_step >= next_start_step,
next_cosine_learning_rate, learning_rate)
tf.compat.v1.logging.info("DEBUG lr %r next lr %r", learning_rate,
next_cosine_learning_rate)
return learning_rate
def get_config(self):
return {
"boundaries": self.boundaries,
"values": self.values,
"offset": self.offset,
"name": self.name
}
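A short usage sketch (not part of the commit): construct the schedule, probe a few steps, and hand it to a Keras optimizer. The boundary, value, and offset choices below are illustrative only:

```python
# Hypothetical usage of the schedule defined above.
import tensorflow as tf

schedule = StepCosineDecayWithOffset(
    boundaries=[0, 500000],   # segment start steps
    values=[1.0e-4, 0.5e-4],  # starting learning rate of each segment
    offset=10000)             # e.g. steps consumed by a warmup phase

for step in (10000, 260000, 499999):
  print(step, float(schedule(step)))  # decays from 1e-4 towards 5e-5

# Any LearningRateSchedule can be passed directly to a Keras optimizer.
optimizer = tf.keras.optimizers.SGD(learning_rate=schedule, momentum=0.9)
```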
@@ -47,6 +47,7 @@ LR_CLS = {
    'power': lr_schedule.DirectPowerDecay,
    'power_linear': lr_schedule.PowerAndLinearDecay,
    'power_with_offset': lr_schedule.PowerDecayWithOffset,
    'step_cosine_with_offset': lr_schedule.StepCosineDecayWithOffset,
}
WARMUP_CLS = {
...
@@ -394,5 +394,38 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
    for step, value in expected_lr_step_values:
      self.assertAlmostEqual(lr(step).numpy(), value)
def test_step_cosine_lr_schedule_with_warmup(self):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {
'momentum': 0.9
}
},
'learning_rate': {
'type': 'step_cosine_with_offset',
'step_cosine_with_offset': {
'values': (0.0001, 0.00005),
'boundaries': (0, 500000),
'offset': 10000,
}
},
'warmup': {
'type': 'linear',
'linear': {
'warmup_steps': 10000,
'warmup_learning_rate': 0.0
}
}
}
expected_lr_step_values = [[0, 0.0], [5000, 1e-4/2.0], [10000, 1e-4],
[20000, 9.994863e-05], [499999, 5e-05]]
opt_config = optimization_config.OptimizationConfig(params)
opt_factory = optimizer_factory.OptimizerFactory(opt_config)
lr = opt_factory.build_learning_rate()
for step, value in expected_lr_step_values:
self.assertAlmostEqual(lr(step).numpy(), value)
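As a sanity check on the expected values above (not part of the test), the entry at step 20000 can be reproduced by hand: after the 10000-step warmup/offset, the schedule is 10000 steps into a 500000-step cosine segment decaying from 1e-4 to 5e-5:

```python
import math

# Hypothetical recomputation of the fourth expected entry.
lo, hi, step, total = 5e-5, 1e-4, 20000 - 10000, 500000
lr = (hi - lo) * (math.cos(math.pi * step / total) + 1.0) / 2.0 + lo
print(lr)  # ~9.9951e-05; agrees with 9.994863e-05 within
           # assertAlmostEqual's default 7-decimal tolerance
```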
if __name__ == '__main__':
  tf.test.main()