"tests/nn/git@developer.sourcefind.cn:OpenDAS/fairscale.git" did not exist on "b89365e65702ee5bd51f8138bfa6246cda9b6b68"
Commit 249dcac3 authored by Le Hou, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 360489888
parent dee618a5
@@ -166,6 +166,29 @@ class PowerAndLinearDecayLrConfig(base_config.Config):
  linear_decay_fraction: float = 0.1


@dataclasses.dataclass
class PowerDecayWithOffsetLrConfig(base_config.Config):
  """Configuration for power learning rate decay with step offset.

  Learning rate equals `pre_offset_learning_rate` if `step` < `offset`.
  Otherwise, the learning rate equals
  `initial_learning_rate * (step - offset)^power`.

  Attributes:
    name: The name of the learning rate schedule.
      Defaults to PowerDecayWithOffset.
    initial_learning_rate: A float. The initial learning rate. Defaults to
      None.
    power: A float. Defaults to -0.5, for sqrt decay.
    offset: An integer. Power decay happens after `offset` steps.
    pre_offset_learning_rate: A float. The constant learning rate before
      `offset` steps.
  """
  name: str = 'PowerDecayWithOffset'
  initial_learning_rate: Optional[float] = None
  power: float = -0.5
  offset: int = 0
  pre_offset_learning_rate: float = 1.0e6


@dataclasses.dataclass
class LinearWarmupConfig(base_config.Config):
  """Configuration for linear warmup schedule config.
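As a quick illustration of the schedule this new config describes (a sketch, not part of the commit; the parameter values are chosen arbitrarily for the example), the two phases of the formula can be traced by hand:

# Illustrative sketch only; parameter values are assumed, not from the commit.
initial_learning_rate, power, offset, pre_offset_learning_rate = 1.0, -1.0, 10, 3.0

def lr_at(step):
  if step < offset:
    return pre_offset_learning_rate                 # constant before the offset
  # Power decay after the offset; the step difference is clamped to >= 1,
  # matching the schedule implementation added later in this commit.
  return initial_learning_rate * max(step - offset, 1) ** power

print(lr_at(5))   # 3.0
print(lr_at(20))  # 1.0 * 10 ** -1.0 = 0.1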
@@ -62,6 +62,7 @@ class LrConfig(oneof.OneOfConfig):
    power: step^power learning rate config.
    power_linear: learning rate config of step^power followed by
      step^power*linear.
    power_with_offset: power decay with a step offset.
  """
  type: Optional[str] = None
  constant: lr_cfg.ConstantLrConfig = lr_cfg.ConstantLrConfig()
@@ -72,6 +73,8 @@ class LrConfig(oneof.OneOfConfig):
  power: lr_cfg.DirectPowerLrConfig = lr_cfg.DirectPowerLrConfig()
  power_linear: lr_cfg.PowerAndLinearDecayLrConfig = (
      lr_cfg.PowerAndLinearDecayLrConfig())
  power_with_offset: lr_cfg.PowerDecayWithOffsetLrConfig = (
      lr_cfg.PowerDecayWithOffsetLrConfig())


@dataclasses.dataclass
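For context, a sketch of how the new one-of option would be selected (illustrative only, not part of the commit): set `type` to 'power_with_offset' and fill in the matching sub-config.

# Sketch only: choosing the new schedule through the one-of config.
lr_config = LrConfig(
    type='power_with_offset',
    power_with_offset=lr_cfg.PowerDecayWithOffsetLrConfig(
        initial_learning_rate=1.0, power=-0.5, offset=10000))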
@@ -40,9 +40,8 @@ class LinearWarmup(tf.keras.optimizers.schedules.LearningRateSchedule):
    Args:
      after_warmup_lr_sched: tf.keras.optimizers.schedules
        .LearningRateSchedule or a constant.
      warmup_steps: Number of the warmup steps.
      warmup_learning_rate: Initial learning rate for the warmup.
      name: Optional, name of warmup schedule.
    """
    super(LinearWarmup, self).__init__()
@@ -164,8 +163,8 @@ class DirectPowerDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Initialize configuration of the learning rate schedule.

    Args:
      initial_learning_rate: The initial learning rate.
      power: The order of the polynomial.
      name: Optional, name of warmup schedule.
    """
    super(DirectPowerDecay, self).__init__()
@@ -209,10 +208,10 @@ class PowerAndLinearDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Initialize configuration of the learning rate schedule.

    Args:
      initial_learning_rate: The initial learning rate.
      total_decay_steps: The total number of steps for power + linear decay.
      power: The order of the polynomial.
      linear_decay_fraction: In the last `linear_decay_fraction` steps,
        the learning rate will be multiplied by a linear decay.
      name: Optional, name of warmup schedule.
    """
@@ -244,3 +243,55 @@ class PowerAndLinearDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
        "linear_decay_fraction": self._linear_decay_fraction,
        "name": self._name,
    }


class PowerDecayWithOffset(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Power learning rate decay with offset.

  Learning rate equals `pre_offset_learning_rate` if `step` < `offset`.
  Otherwise, the learning rate equals
  `initial_learning_rate * (step - offset)^power`.
  """

  def __init__(self,
               initial_learning_rate: float,
               power: float = 1.0,
               offset: int = 0,
               pre_offset_learning_rate: float = 1.0e6,
               name: str = "PowerDecayWithOffset"):
    """Initialize configuration of the learning rate schedule.

    Args:
      initial_learning_rate: The initial learning rate.
      power: The order of the polynomial.
      offset: The offset when computing the power decay.
      pre_offset_learning_rate: The maximum learning rate we'll use.
      name: Optional, name of warmup schedule.
    """
    super(PowerDecayWithOffset, self).__init__()
    self._initial_learning_rate = initial_learning_rate
    self._power = power
    self._offset = offset
    self._pre_offset_lr = pre_offset_learning_rate
    self._name = name

  def __call__(self, step):
    with tf.name_scope(self._name or "PowerDecayWithOffset"):
      step = tf.cast(step, tf.float32)
      lr_after_offset = tf.math.pow(
          tf.math.maximum(step - self._offset, 1.0), self._power) * (
              self._initial_learning_rate)
      # Select the constant pre-offset rate before `offset` steps and the
      # power-decayed rate afterwards.
      sign = tf.cast(step > self._offset, tf.float32)
      lr_combined = (1.0 - sign) * self._pre_offset_lr + sign * lr_after_offset
      # Power may give infinitely large LR. So cap it with pre_offset_lr.
      return tf.math.minimum(lr_combined, self._pre_offset_lr)

  def get_config(self):
    """Get the configuration of the learning rate schedule."""
    return {
        "initial_learning_rate": self._initial_learning_rate,
        "power": self._power,
        "offset": self._offset,
        "pre_offset_learning_rate": self._pre_offset_lr,
        "name": self._name,
    }
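A minimal usage sketch of the schedule defined above (illustrative, not part of the commit; the constructor arguments mirror the test added below):

import tensorflow as tf  # assumed available, as in the module being patched

schedule = PowerDecayWithOffset(
    initial_learning_rate=1.0, power=-1.0, offset=10,
    pre_offset_learning_rate=3.0)
print(schedule(5).numpy())   # 3.0, the pre-offset constant rate
print(schedule(20).numpy())  # 0.1 = 1.0 * (20 - 10) ** -1.0
# Like any LearningRateSchedule, it can be passed straight to a Keras optimizer.
optimizer = tf.keras.optimizers.SGD(learning_rate=schedule, momentum=0.9)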
@@ -40,6 +40,7 @@ LR_CLS = {
    'cosine': tf.keras.experimental.CosineDecay,
    'power': lr_schedule.DirectPowerDecay,
    'power_linear': lr_schedule.PowerAndLinearDecay,
    'power_with_offset': lr_schedule.PowerDecayWithOffset,
}

WARMUP_CLS = {
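The LR_CLS entry above is what lets the optimizer factory resolve the 'power_with_offset' choice from the one-of config to the new schedule class. Roughly, the lookup amounts to something like this (a sketch, not the factory's actual code):

# Sketch only: resolving a learning-rate type name via the LR_CLS registry.
lr_type = 'power_with_offset'                  # taken from LrConfig.type
lr_kwargs = {'initial_learning_rate': 1.0,     # taken from the matching sub-config
             'power': -1.0, 'offset': 10, 'pre_offset_learning_rate': 3.0}
lr = LR_CLS[lr_type](**lr_kwargs)              # -> lr_schedule.PowerDecayWithOffset(...)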
@@ -365,6 +365,32 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
    for step, value in expected_lr_step_values:
      self.assertAlmostEqual(lr(step).numpy(), value)

  def test_power_with_offset_lr_schedule(self):
    params = {
        'optimizer': {
            'type': 'sgd',
            'sgd': {
                'momentum': 0.9
            }
        },
        'learning_rate': {
            'type': 'power_with_offset',
            'power_with_offset': {
                'initial_learning_rate': 1.0,
                'power': -1.0,
                'offset': 10,
                'pre_offset_learning_rate': 3.0,
            }
        }
    }
    expected_lr_step_values = [[1, 3.0], [10, 3.0], [20, 1. / 10.]]
    opt_config = optimization_config.OptimizationConfig(params)
    opt_factory = optimizer_factory.OptimizerFactory(opt_config)
    lr = opt_factory.build_learning_rate()

    for step, value in expected_lr_step_values:
      self.assertAlmostEqual(lr(step).numpy(), value)


if __name__ == '__main__':
  tf.test.main()