Commit f03e5f9d authored by Le Hou's avatar Le Hou Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 347634117
parent 92cc9dbc
......@@ -146,6 +146,27 @@ class DirectPowerLrConfig(base_config.Config):
power: float = -0.5
@dataclasses.dataclass
class PowerAndLinearDecayLrConfig(base_config.Config):
"""Configuration for DirectPower learning rate decay.
This class configures a schedule following follows lr * (step)^power for the
first total_decay_steps * (1 - linear_decay_fraction) steps, and follows
lr * (step)^power * (total_decay_steps - step) / (total_decay_steps *
linear_decay_fraction) for the rest of the steps.
Attributes:
name: The name of the learning rate schedule. Defaults to DirectPowerDecay.
initial_learning_rate: A float. The initial learning rate. Defaults to None.
power: A float. Defaults to -0.5, for sqrt decay.
"""
name: str = 'PowerAndLinearDecay'
initial_learning_rate: Optional[float] = None
total_decay_steps: Optional[int] = None
power: float = -0.5
linear_decay_fraction: float = 0.1
@dataclasses.dataclass
class LinearWarmupConfig(base_config.Config):
"""Configuration for linear warmup schedule config.
......
......@@ -61,6 +61,8 @@ class LrConfig(oneof.OneOfConfig):
polynomial: polynomial learning rate config.
cosine: cosine learning rate config.
power: step^power learning rate config.
power_linear: learning rate config of step^power followed by
step^power*linear.
"""
type: Optional[str] = None
constant: lr_cfg.ConstantLrConfig = lr_cfg.ConstantLrConfig()
......@@ -69,6 +71,8 @@ class LrConfig(oneof.OneOfConfig):
polynomial: lr_cfg.PolynomialLrConfig = lr_cfg.PolynomialLrConfig()
cosine: lr_cfg.CosineLrConfig = lr_cfg.CosineLrConfig()
power: lr_cfg.DirectPowerLrConfig = lr_cfg.DirectPowerLrConfig()
power_linear: lr_cfg.PowerAndLinearDecayLrConfig = (
lr_cfg.PowerAndLinearDecayLrConfig())
@dataclasses.dataclass
......
......@@ -188,3 +188,58 @@ class DirectPowerDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
"power": self._power,
"name": self._name,
}
class PowerAndLinearDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
"""Learning rate schedule with multiplied by linear decay at the end.
follows lr * (step)^power for the first total_decay_steps *
(1 - linear_decay_fraction) steps, and follows lr * (step)^power *
(total_decay_steps - step) / (total_decay_steps * linear_decay_fraction)
for the rest of the steps.
"""
def __init__(self,
initial_learning_rate: float,
total_decay_steps: int,
power: float = 1.0,
linear_decay_fraction: float = 0.1,
name: str = "PowerAndLinearDecay"):
"""Initialize configuration of the learning rate schedule.
Args:
initial_learning_rate: A float, the initial learning rate.
total_decay_steps: The total number of steps for power + linear decay.
power: A float, the number of steps required for linear warmup.
linear_decay_fraction: A float, in the last `linear_decay_fraction` steps,
the learning rate will be multiplied by a linear decay.
name: Optional, name of warmup schedule.
"""
super(PowerAndLinearDecay, self).__init__()
self._initial_learning_rate = initial_learning_rate
self._total_decay_steps = total_decay_steps
self._power = power
self._linear_decay_fraction = linear_decay_fraction
self._name = name
def __call__(self, step):
with tf.name_scope(self._name or "PowerAndLinearDecay"):
step = tf.cast(step, tf.float32)
learning_rate = self._initial_learning_rate
learning_rate *= tf.math.pow(step, self._power)
if self._linear_decay_fraction > 0:
learning_rate *= tf.minimum(
1.0, (self._total_decay_steps - step) /
(self._total_decay_steps * self._linear_decay_fraction))
learning_rate = tf.maximum(0.0, learning_rate)
return learning_rate
def get_config(self):
"""Get the configuration of the learning rate schedule."""
return {
"initial_learning_rate": self._initial_learning_rate,
"total_decay_steps": self._total_decay_steps,
"power": self._power,
"linear_decay_fraction": self._linear_decay_fraction,
"name": self._name,
}
......@@ -40,6 +40,7 @@ LR_CLS = {
'exponential': tf.keras.optimizers.schedules.ExponentialDecay,
'cosine': tf.keras.experimental.CosineDecay,
'power': lr_schedule.DirectPowerDecay,
'power_linear': lr_schedule.PowerAndLinearDecay,
}
WARMUP_CLS = {
......
......@@ -340,6 +340,32 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
for step, value in expected_lr_step_values:
self.assertAlmostEqual(lr(step).numpy(), value)
def test_power_linear_lr_schedule(self):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {
'momentum': 0.9
}
},
'learning_rate': {
'type': 'power_linear',
'power_linear': {
'initial_learning_rate': 1.0,
'power': -1.0,
'linear_decay_fraction': 0.5,
'total_decay_steps': 100,
}
}
}
expected_lr_step_values = [[1, 1.0], [40, 1. / 40.], [60, 1. / 60. * 0.8]]
opt_config = optimization_config.OptimizationConfig(params)
opt_factory = optimizer_factory.OptimizerFactory(opt_config)
lr = opt_factory.build_learning_rate()
for step, value in expected_lr_step_values:
self.assertAlmostEqual(lr(step).numpy(), value)
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment