"git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "4b843668c3f0af069bbe7c4c09cdece952f3fec4"
Commit afe4802e authored by Yuexin Wu, committed by A. Unique TensorFlower

Add step offset support for PowerAndLinearDecay.

PiperOrigin-RevId: 372675330
parent 6a7b4d1a
@@ -149,21 +149,35 @@ class DirectPowerLrConfig(base_config.Config):
class PowerAndLinearDecayLrConfig(base_config.Config):
"""Configuration for DirectPower learning rate decay.
This class configures a schedule following follows lr * (step)^power for the
first total_decay_steps * (1 - linear_decay_fraction) steps, and follows
lr * (step)^power * (total_decay_steps - step) / (total_decay_steps *
linear_decay_fraction) for the rest of the steps.
The schedule has the following behavior.
Let offset_step = step - offset.
1) If offset_step < 0, the actual learning rate equals initial_learning_rate.
2) If 0 <= offset_step <= total_decay_steps * (1 - linear_decay_fraction), the
actual learning rate equals lr * offset_step^power.
3) If total_decay_steps * (1 - linear_decay_fraction) < offset_step <
total_decay_steps, the actual learning rate equals lr * offset_step^power *
(total_decay_steps - offset_step) / (total_decay_steps *
linear_decay_fraction).
4) If offset_step >= total_decay_steps, the actual learning rate equals zero.
Attributes:
name: The name of the learning rate schedule. Defaults to DirectPowerDecay.
name: The name of the learning rate schedule. Defaults to
PowerAndLinearDecay.
initial_learning_rate: A float. The initial learning rate. Defaults to None.
power: A float. Defaults to -0.5, for sqrt decay.
total_decay_steps: An int. The total number of steps for power + linear
decay. Defaults to None.
power: A float. The order of the polynomial. Defaults to -0.5, for sqrt
decay.
linear_decay_fraction: A float. In the last `linear_decay_fraction` steps,
the learning rate will be multiplied by a linear decay. Defaults to 0.1.
offset: An int. The offset applied to steps. Defaults to 0.
"""
name: str = 'PowerAndLinearDecay'
initial_learning_rate: Optional[float] = None
total_decay_steps: Optional[int] = None
power: float = -0.5
linear_decay_fraction: float = 0.1
offset: int = 0
@dataclasses.dataclass
@@ -174,8 +188,8 @@ class PowerDecayWithOffsetLrConfig(base_config.Config):
Otherwise, learning rate equals to lr * (step - offset)^power.
Attributes:
name: The name of the learning rate schedule.
Defaults to PowerDecayWithOffset.
name: The name of the learning rate schedule. Defaults to
PowerDecayWithOffset.
initial_learning_rate: A float. The initial learning rate. Defaults to None.
power: A float. Defaults to -0.5, for sqrt decay.
offset: An integer. Power decay happens after `offset` steps.
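To see what the new field looks like in use, here is a minimal, hedged sketch of constructing the updated config directly. The field names come from the PowerAndLinearDecayLrConfig hunk above; the import path is an assumption based on the class and module names, not something shown in this diff.

```python
# Hypothetical import path; the dataclass fields below are the ones listed
# in the diff (initial_learning_rate, total_decay_steps, power,
# linear_decay_fraction, and the new offset).
from official.modeling.optimization.configs import learning_rate_config as lr_cfg

cfg = lr_cfg.PowerAndLinearDecayLrConfig(
    initial_learning_rate=1.0,
    total_decay_steps=100,
    power=-1.0,
    linear_decay_fraction=0.5,
    offset=90)  # the decay only starts counting 90 steps in
```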
@@ -22,9 +22,11 @@ import tensorflow as tf
class LinearWarmup(tf.keras.optimizers.schedules.LearningRateSchedule):
"""Linear warmup schedule."""
def __init__(self, after_warmup_lr_sched: Union[
tf.keras.optimizers.schedules.LearningRateSchedule, float],
warmup_steps: int, warmup_learning_rate: float,
def __init__(self,
after_warmup_lr_sched: Union[
tf.keras.optimizers.schedules.LearningRateSchedule, float],
warmup_steps: int,
warmup_learning_rate: float,
name: Optional[str] = None):
"""Add linear warmup schedule to a learning rate schedule.
@@ -38,8 +40,8 @@ class LinearWarmup(tf.keras.optimizers.schedules.LearningRateSchedule):
steps.
Args:
after_warmup_lr_sched: tf.keras.optimizers.schedules
.LearningRateSchedule or a constant.
after_warmup_lr_sched: tf.keras.optimizers.schedules .LearningRateSchedule
or a constant.
warmup_steps: Number of the warmup steps.
warmup_learning_rate: Initial learning rate for the warmup.
name: Optional, name of warmup schedule.
@@ -53,8 +55,7 @@ class LinearWarmup(tf.keras.optimizers.schedules.LearningRateSchedule):
tf.keras.optimizers.schedules.LearningRateSchedule):
self._final_warmup_lr = after_warmup_lr_sched(warmup_steps)
else:
self._final_warmup_lr = tf.cast(
after_warmup_lr_sched, dtype=tf.float32)
self._final_warmup_lr = tf.cast(after_warmup_lr_sched, dtype=tf.float32)
def __call__(self, step: int):
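Since these hunks only reflow the LinearWarmup signature and the cast, a short usage sketch may help; the constructor arguments are the ones documented above, and the numeric values are illustrative only.

```python
import tensorflow as tf

from official.modeling.optimization import lr_schedule

# Warm up linearly from 0.0 to a constant 1e-3 over the first 1000 steps;
# per the Union type in the signature, a LearningRateSchedule could be
# passed for `after_warmup_lr_sched` instead of a constant.
warmup_lr = lr_schedule.LinearWarmup(
    after_warmup_lr_sched=1e-3,
    warmup_steps=1000,
    warmup_learning_rate=0.0)
print(warmup_lr(500))  # roughly halfway between the two rates during warmup
```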
@@ -92,8 +93,7 @@ class LinearWarmup(tf.keras.optimizers.schedules.LearningRateSchedule):
class PolynomialWarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
"""Applies polynomial warmup schedule on a given learning rate decay schedule.
"""
"""Applies polynomial warmup schedule on a given learning rate decay schedule."""
def __init__(self,
after_warmup_lr_sched: Union[
@@ -172,7 +172,7 @@ class DirectPowerDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
Args:
initial_learning_rate: The initial learning rate.
power: The order of the polynomial.
name: Optional, name of warmup schedule.
name: Optional, name of learning rate schedule.
"""
super().__init__()
self._initial_learning_rate = initial_learning_rate
@@ -200,10 +200,16 @@ class DirectPowerDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
class PowerAndLinearDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
"""Learning rate schedule with multiplied by linear decay at the end.
follows lr * (step)^power for the first total_decay_steps *
(1 - linear_decay_fraction) steps, and follows lr * (step)^power *
(total_decay_steps - step) / (total_decay_steps * linear_decay_fraction)
for the rest of the steps.
The schedule has the following behavior.
Let offset_step = step - offset.
1) If offset_step < 0, the actual learning rate equals initial_learning_rate.
2) If 0 <= offset_step <= total_decay_steps * (1 - linear_decay_fraction), the
actual learning rate equals lr * offset_step^power.
3) If total_decay_steps * (1 - linear_decay_fraction) < offset_step <
total_decay_steps, the actual learning rate equals lr * offset_step^power *
(total_decay_steps - offset_step) / (total_decay_steps *
linear_decay_fraction).
4) If offset_step >= total_decay_steps, the actual learning rate equals zero.
"""
def __init__(self,
@@ -211,6 +217,7 @@ class PowerAndLinearDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
total_decay_steps: int,
power: float = 1.0,
linear_decay_fraction: float = 0.1,
offset: int = 0,
name: str = "PowerAndLinearDecay"):
"""Initialize configuration of the learning rate schedule.
@@ -218,20 +225,22 @@ class PowerAndLinearDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
initial_learning_rate: The initial learning rate.
total_decay_steps: The total number of steps for power + linear decay.
power: The order of the polynomial.
linear_decay_fraction: In the last `linear_decay_fraction` steps,
the learning rate will be multiplied by a linear decay.
name: Optional, name of warmup schedule.
linear_decay_fraction: In the last `linear_decay_fraction` steps, the
learning rate will be multiplied by a linear decay.
offset: The offset applied to steps.
name: Optional, name of learning rate schedule.
"""
super().__init__()
self._initial_learning_rate = initial_learning_rate
self._total_decay_steps = total_decay_steps
self._power = power
self._linear_decay_fraction = linear_decay_fraction
self._offset = offset
self._name = name
def __call__(self, step):
with tf.name_scope(self._name or "PowerAndLinearDecay"):
step = tf.cast(step, tf.float32)
step = tf.cast(step - self._offset, tf.float32)
learning_rate = self._initial_learning_rate
# A zero `step` may cause Inf. So make `step` positive.
step_non_zero = tf.math.maximum(step, 1.0)
@@ -250,6 +259,7 @@ class PowerAndLinearDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
"total_decay_steps": self._total_decay_steps,
"power": self._power,
"linear_decay_fraction": self._linear_decay_fraction,
"offset": self._offset,
"name": self._name,
}
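Because the part of `__call__` past the `step_non_zero` clamp is not shown in these hunks, the following plain-Python sketch is a hedged reconstruction of the documented piecewise behavior, written only to make the offset handling concrete. It reproduces the expected values in the test cases below, but the clamping details are an assumption, not a copy of the TF implementation.

```python
def power_and_linear_lr(step, lr, total_decay_steps, power,
                        linear_decay_fraction, offset=0):
  """Hedged reference for the schedule described in the docstring above."""
  offset_step = step - offset
  # Clamp to 1 so a non-positive step does not blow up under a negative power.
  value = lr * max(offset_step, 1.0)**power
  if total_decay_steps * linear_decay_fraction > 0:
    # Linear factor over the last `linear_decay_fraction` of the decay steps,
    # clipped to [0, 1] so early steps are unaffected and late steps reach 0.
    linear = (total_decay_steps - offset_step) / (
        total_decay_steps * linear_decay_fraction)
    value *= min(max(linear, 0.0), 1.0)
  return value

# Matches the 'offset' test case below (power=-1, fraction=0.5, offset=90).
assert abs(power_and_linear_lr(130, 1.0, 100, -1.0, 0.5, 90) - 1.0 / 40) < 1e-9
assert abs(power_and_linear_lr(150, 1.0, 100, -1.0, 0.5, 90) - 0.8 / 60) < 1e-9
```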
@@ -274,7 +284,7 @@ class PowerDecayWithOffset(tf.keras.optimizers.schedules.LearningRateSchedule):
power: The order of the polynomial.
offset: The offset when computing the power decay.
pre_offset_learning_rate: The maximum learning rate we'll use.
name: Optional, name of warmup schedule.
name: Optional, name of learning rate schedule.
"""
super().__init__()
self._initial_learning_rate = initial_learning_rate
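For the neighboring PowerDecayWithOffset schedule, whose only change here is the clarified `name` doc, a hedged usage sketch using the argument names from the Args section above; the numeric values are illustrative only.

```python
# Stays at pre_offset_learning_rate for the first `offset` steps, then
# decays as lr * (step - offset)^power per the class docstring.
pdwo_lr = lr_schedule.PowerDecayWithOffset(
    initial_learning_rate=1.0,
    power=-0.5,
    offset=1000,
    pre_offset_learning_rate=1.0)
```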
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for lr_schedule."""
from absl.testing import parameterized
import tensorflow as tf
from official.modeling.optimization import lr_schedule
class PowerAndLinearDecayTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(
dict(
testcase_name='power_only',
init_lr=1.0,
power=-1.0,
linear_decay_fraction=0.0,
total_decay_steps=100,
offset=0,
expected=[[0, 1.0], [1, 1.0], [40, 1. / 40.], [60, 1. / 60],
[100, 1. / 100]]),
dict(
testcase_name='linear_only',
init_lr=1.0,
power=0.0,
linear_decay_fraction=1.0,
total_decay_steps=100,
offset=0,
expected=[[0, 1.0], [1, 0.99], [40, 0.6], [60, 0.4], [100, 0.0]]),
dict(
testcase_name='general',
init_lr=1.0,
power=-1.0,
linear_decay_fraction=0.5,
total_decay_steps=100,
offset=0,
expected=[[0, 1.0], [1, 1.0], [40, 1. / 40.],
[60, 1. / 60. * 0.8], [100, 0.0]]),
dict(
testcase_name='offset',
init_lr=1.0,
power=-1.0,
linear_decay_fraction=0.5,
total_decay_steps=100,
offset=90,
expected=[[0, 1.0], [90, 1.0], [91, 1.0], [130, 1. / 40.],
[150, 1. / 60. * 0.8], [190, 0.0], [200, 0.0]]),
)
def test_power_linear_lr_schedule(self, init_lr, power, linear_decay_fraction,
total_decay_steps, offset, expected):
lr = lr_schedule.PowerAndLinearDecay(
initial_learning_rate=init_lr,
power=power,
linear_decay_fraction=linear_decay_fraction,
total_decay_steps=total_decay_steps,
offset=offset)
for step, value in expected:
self.assertAlmostEqual(lr(step).numpy(), value)
if __name__ == '__main__':
tf.test.main()
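As a sanity check on the 'offset' expectations above: at step 150 with offset=90, offset_step = 150 - 90 = 60; the power term is 60^-1 = 1/60, and since 60 exceeds total_decay_steps * (1 - linear_decay_fraction) = 50, the linear term is (100 - 60) / (100 * 0.5) = 0.8, giving 1/60 * 0.8 ≈ 0.0133, which is exactly the 1. / 60. * 0.8 listed in the expected values.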
@@ -107,6 +107,9 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
optimizer_factory.OptimizerFactory(
optimization_config.OptimizationConfig(params))
# TODO(b/187559334) refactor lr_schedule tests into `lr_schedule_test.py`.
def test_stepwise_lr_schedule(self):
params = {
'optimizer': {
@@ -352,6 +355,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
'power': -1.0,
'linear_decay_fraction': 0.5,
'total_decay_steps': 100,
'offset': 0,
}
}
}
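Only the inner learning-rate fragment is visible in this hunk, so the following is a hedged sketch of what the full params dict for this factory test might look like. The surrounding 'optimizer' / 'learning_rate' structure, the 'power_linear' type name, and the build_learning_rate() call are assumptions based on how OptimizerFactory is used earlier in this file, not lines taken from the diff.

```python
# Hypothetical full params; only the innermost dict appears in the hunk.
params = {
    'optimizer': {'type': 'sgd'},
    'learning_rate': {
        'type': 'power_linear',  # assumed one-of name for this schedule
        'power_linear': {
            'initial_learning_rate': 1.0,
            'power': -1.0,
            'linear_decay_fraction': 0.5,
            'total_decay_steps': 100,
            'offset': 0,
        }
    }
}
opt_config = optimization_config.OptimizationConfig(params)
lr = optimizer_factory.OptimizerFactory(opt_config).build_learning_rate()
```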
@@ -390,6 +394,5 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
for step, value in expected_lr_step_values:
self.assertAlmostEqual(lr(step).numpy(), value)
if __name__ == '__main__':
tf.test.main()