# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Learning rate schedule classes."""

from typing import Mapping, Any, Union, Optional

import tensorflow as tf


class LinearWarmup(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Linear warmup schedule.

  For steps below `warmup_steps` the learning rate is interpolated linearly
  from `warmup_learning_rate` towards the value the wrapped schedule (or
  constant) takes at `warmup_steps`. From `warmup_steps` onwards the wrapped
  schedule (or constant) is used unchanged.
  """

  def __init__(self, after_warmup_lr_sched: Union[
      tf.keras.optimizers.schedules.LearningRateSchedule, float],
               warmup_steps: int, warmup_learning_rate: float,
               name: Optional[str] = None):
    """Add linear warmup schedule to a learning rate schedule.

    warmup_lr is the initial learning rate, the final learning rate of the
    init_warmup period is the initial learning rate of lr_schedule in use.
    The learning rate at each step linearly increased according to the following
    formula:
      learning_rate = warmup_lr + step / warmup_steps
                    * (final_warmup_lr - warmup_lr).
    Using warmup overrides the learning rate schedule by the number of warmup
    steps.

    Args:
      after_warmup_lr_sched: tf.keras.optimizers.schedules
                                .LearningRateSchedule or a constant.
      warmup_steps: Number of the warmup steps.
      warmup_learning_rate: Initial learning rate for the warmup.
      name: Optional, name of warmup schedule.
    """
    super(LinearWarmup, self).__init__()
    self._name = name
    self._after_warmup_lr_sched = after_warmup_lr_sched
    self._warmup_steps = warmup_steps
    self._init_warmup_lr = warmup_learning_rate
    # The warmup ramp ends at whatever the wrapped schedule yields at
    # `warmup_steps`, so the two pieces join without a jump.
    if isinstance(after_warmup_lr_sched,
                  tf.keras.optimizers.schedules.LearningRateSchedule):
      self._final_warmup_lr = after_warmup_lr_sched(warmup_steps)
    else:
      self._final_warmup_lr = tf.cast(
          after_warmup_lr_sched, dtype=tf.float32)

  def __call__(self, step: int):
    """Returns the learning rate for `step`.

    Args:
      step: The current training step.

    Returns:
      The linearly interpolated warmup learning rate when
      `step < warmup_steps`, otherwise the value of the wrapped schedule (or
      constant) at `step`.
    """
    global_step = tf.cast(step, dtype=tf.float32)
    warmup_steps = tf.cast(self._warmup_steps, dtype=tf.float32)

    # Both branch values are computed before `tf.cond` selects one, so a plain
    # division would produce Inf/NaN when `warmup_steps` is 0 even though the
    # warmup branch is never selected in that case. `divide_no_nan` returns 0
    # for a zero denominator and is identical to `/` otherwise.
    warmup_fraction = tf.math.divide_no_nan(global_step, warmup_steps)
    linear_warmup_lr = (
        self._init_warmup_lr + warmup_fraction *
        (self._final_warmup_lr - self._init_warmup_lr))

    if isinstance(self._after_warmup_lr_sched,
                  tf.keras.optimizers.schedules.LearningRateSchedule):
      after_warmup_lr = self._after_warmup_lr_sched(step)
    else:
      after_warmup_lr = tf.cast(self._after_warmup_lr_sched, dtype=tf.float32)

    lr = tf.cond(global_step < warmup_steps,
                 lambda: linear_warmup_lr,
                 lambda: after_warmup_lr)
    return lr

  def get_config(self) -> Mapping[str, Any]:
    """Returns the configuration of the warmup schedule.

    A wrapped `LearningRateSchedule` is stored as the dict returned by its own
    `get_config()`; a constant is stored as-is.
    """
    if isinstance(self._after_warmup_lr_sched,
                  tf.keras.optimizers.schedules.LearningRateSchedule):
      config = {
          "after_warmup_lr_sched": self._after_warmup_lr_sched.get_config()}  # pytype: disable=attribute-error
    else:
      config = {"after_warmup_lr_sched": self._after_warmup_lr_sched}  # pytype: disable=attribute-error

    config.update({
        "warmup_steps": self._warmup_steps,
        "warmup_learning_rate": self._init_warmup_lr,
        "name": self._name
    })
    return config


class PolynomialWarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Applies polynomial warmup schedule on a given learning rate decay schedule.

  For steps below `warmup_steps` the learning rate is
  `init_lr * (step / warmup_steps) ** power`, where `init_lr` is the value of
  the wrapped schedule (or constant) at `warmup_steps`. From `warmup_steps`
  onwards the wrapped schedule (or constant) is used unchanged.
  """

  def __init__(self,
               after_warmup_lr_sched: Union[
                   tf.keras.optimizers.schedules.LearningRateSchedule, float],
               warmup_steps: int,
               power: float = 1.0,
               name: str = "PolynomialWarmup"):
    """Add polynomial warmup schedule to a learning rate schedule.

    Args:
      after_warmup_lr_sched: tf.keras.optimizers.schedules
                                .LearningRateSchedule or a constant.
      warmup_steps: Number of the warmup steps.
      power: The order of the polynomial.
      name: Optional, name of warmup schedule.
    """
    super(PolynomialWarmUp, self).__init__()
    # The warmup targets the value the wrapped schedule yields at
    # `warmup_steps`.
    if isinstance(after_warmup_lr_sched,
                  tf.keras.optimizers.schedules.LearningRateSchedule):
      self._initial_learning_rate = after_warmup_lr_sched(warmup_steps)
    else:
      self._initial_learning_rate = tf.cast(
          after_warmup_lr_sched, dtype=tf.float32)

    self._warmup_steps = warmup_steps
    self._power = power
    self._after_warmup_lr_sched = after_warmup_lr_sched
    self._name = name

  def __call__(self, step):
    """Returns the learning rate for `step`.

    Args:
      step: The current training step.

    Returns:
      The polynomial warmup learning rate when `step < warmup_steps`,
      otherwise the value of the wrapped schedule (or constant) at `step`.
    """
    with tf.name_scope(self._name or "PolynomialWarmUp") as name:
      # Implements polynomial warmup. i.e., if global_step < warmup_steps, the
      # learning rate will be
      # `(global_step / warmup_steps) ** power * init_lr`.
      global_step_float = tf.cast(step, tf.float32)
      warmup_steps_float = tf.cast(self._warmup_steps, tf.float32)

      if self._warmup_steps <= 0:
        # No warmup requested; avoid dividing by a non-positive step count.
        warmup_percent_done = 1.0
      else:
        # A zero `step` may cause Inf. So make `step` positive.
        step_non_zero = tf.math.maximum(global_step_float, 1.0)
        warmup_percent_done = step_non_zero / warmup_steps_float

      warmup_learning_rate = (
          self._initial_learning_rate *
          tf.math.pow(warmup_percent_done, self._power))

      if isinstance(self._after_warmup_lr_sched,
                    tf.keras.optimizers.schedules.LearningRateSchedule):
        after_warmup_lr = self._after_warmup_lr_sched(step)
      else:
        after_warmup_lr = tf.cast(self._after_warmup_lr_sched, dtype=tf.float32)

      return tf.cond(
          global_step_float < warmup_steps_float,
          lambda: warmup_learning_rate,
          lambda: after_warmup_lr,
          name=name)

  def get_config(self) -> Mapping[str, Any]:
    """Returns the configuration of the warmup schedule.

    A wrapped `LearningRateSchedule` is stored as the dict returned by its own
    `get_config()`; a constant is stored as-is.
    """
    if isinstance(self._after_warmup_lr_sched,
                  tf.keras.optimizers.schedules.LearningRateSchedule):
      config = {
          "after_warmup_lr_sched": self._after_warmup_lr_sched.get_config()}  # pytype: disable=attribute-error
    else:
      config = {"after_warmup_lr_sched": self._after_warmup_lr_sched}  # pytype: disable=attribute-error

    config.update({
        "warmup_steps": self._warmup_steps,
        "power": self._power,
        "name": self._name
    })
    return config
161
162
163
164
165
166
167
168
169
170
171
172


class DirectPowerDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Learning rate schedule follows lr * (step)^power."""

  def __init__(self,
               initial_learning_rate: float,
               power: float = 1.0,
               name: str = "DirectPowerDecay"):
    """Initialize configuration of the learning rate schedule.

    Args:
      initial_learning_rate: The initial learning rate.
      power: The order of the polynomial.
      name: Optional, name of the learning rate schedule.
    """
    super(DirectPowerDecay, self).__init__()
    self._initial_learning_rate = initial_learning_rate
    self._power = power
    self._name = name

  def __call__(self, step):
    """Returns `initial_learning_rate * max(step, 1)^power` for `step`."""
    with tf.name_scope(self._name or "DirectPowerDecay"):
      step = tf.cast(step, tf.float32)
      learning_rate = self._initial_learning_rate
      # A zero `step` may cause Inf. So make `step` positive.
      step_non_zero = tf.math.maximum(step, 1.0)
      learning_rate *= tf.math.pow(step_non_zero, self._power)
      return learning_rate

  def get_config(self):
    """Get the configuration of the learning rate schedule."""
    return {
        "initial_learning_rate": self._initial_learning_rate,
        "power": self._power,
        "name": self._name,
    }
Le Hou's avatar
Le Hou committed
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217


class PowerAndLinearDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Power learning rate schedule multiplied by a linear decay at the end.

  Follows lr * (step)^power for the first total_decay_steps *
  (1 - linear_decay_fraction) steps, and follows lr * (step)^power *
  (total_decay_steps - step) / (total_decay_steps * linear_decay_fraction)
  for the rest of the steps. When the linear decay is active, the result is
  clamped so that it never goes below zero.
  """

  def __init__(self,
               initial_learning_rate: float,
               total_decay_steps: int,
               power: float = 1.0,
               linear_decay_fraction: float = 0.1,
               name: str = "PowerAndLinearDecay"):
    """Initialize configuration of the learning rate schedule.

    Args:
      initial_learning_rate: The initial learning rate.
      total_decay_steps: The total number of steps for power + linear decay.
      power: The order of the polynomial.
      linear_decay_fraction: In the last `linear_decay_fraction` steps,
        the learning rate will be multiplied by a linear decay.
      name: Optional, name of the learning rate schedule.
    """
    super(PowerAndLinearDecay, self).__init__()
    self._initial_learning_rate = initial_learning_rate
    self._total_decay_steps = total_decay_steps
    self._power = power
    self._linear_decay_fraction = linear_decay_fraction
    self._name = name

  def __call__(self, step):
    """Returns the learning rate for `step`."""
    with tf.name_scope(self._name or "PowerAndLinearDecay"):
      step = tf.cast(step, tf.float32)
      learning_rate = self._initial_learning_rate
      # A zero `step` may cause Inf. So make `step` positive.
      step_non_zero = tf.math.maximum(step, 1.0)
      learning_rate *= tf.math.pow(step_non_zero, self._power)
      # Skip the linear factor entirely when the linear-decay window is empty,
      # which would otherwise divide by zero.
      if self._total_decay_steps * self._linear_decay_fraction > 0:
        # The factor stays at 1 until the linear-decay window starts, then
        # falls linearly and reaches 0 at `total_decay_steps`.
        learning_rate *= tf.minimum(
            1.0, (self._total_decay_steps - step) /
            (self._total_decay_steps * self._linear_decay_fraction))
        # Past `total_decay_steps` the factor turns negative; clamp at zero.
        learning_rate = tf.maximum(0.0, learning_rate)
      return learning_rate

  def get_config(self):
    """Get the configuration of the learning rate schedule."""
    return {
        "initial_learning_rate": self._initial_learning_rate,
        "total_decay_steps": self._total_decay_steps,
        "power": self._power,
        "linear_decay_fraction": self._linear_decay_fraction,
        "name": self._name,
    }
Le Hou's avatar
Le Hou committed
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306


class PowerDecayWithOffset(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Power learning rate decay with offset.

  The learning rate is `pre_offset_learning_rate` while `step` <= `offset`.
  Afterwards it is lr * (step - offset)^power, capped from above by
  `pre_offset_learning_rate`.
  """

  def __init__(self,
               initial_learning_rate: float,
               power: float = 1.0,
               offset: int = 0,
               pre_offset_learning_rate: float = 1.0e6,
               name: str = "PowerDecayWithOffset"):
    """Stores the hyper-parameters of the schedule.

    Args:
      initial_learning_rate: Multiplier applied to the power term.
      power: Exponent applied to the offset-shifted step.
      offset: Step count subtracted from `step` before the power is applied.
      pre_offset_learning_rate: Learning rate used up to and including
        `offset`; also the upper bound on the returned learning rate.
      name: Optional name used for the schedule's name scope.
    """
    super(PowerDecayWithOffset, self).__init__()
    self._name = name
    self._initial_learning_rate = initial_learning_rate
    self._power = power
    self._offset = offset
    self._pre_offset_lr = pre_offset_learning_rate

  def __call__(self, step):
    """Returns the learning rate for `step`."""
    with tf.name_scope(self._name or "PowerDecayWithOffset"):
      step = tf.cast(step, tf.float32)

      # Clamp the shifted step to at least 1 before raising it to `power`.
      shifted_step = tf.math.maximum(step - self._offset, 1.0)
      decayed_lr = tf.math.pow(shifted_step, self._power) * (
          self._initial_learning_rate)

      # 1.0 once `step` has passed `offset`, 0.0 before that; used to blend
      # the two regimes arithmetically.
      past_offset = tf.cast(step > self._offset, tf.float32)
      blended_lr = (
          (1.0 - past_offset) * self._pre_offset_lr + past_offset * decayed_lr)

      # Power may give infinitely large LR. So cap it with pre_offset_lr.
      return tf.math.minimum(blended_lr, self._pre_offset_lr)

  def get_config(self):
    """Returns the constructor arguments as a dictionary."""
    config = {
        "initial_learning_rate": self._initial_learning_rate,
        "power": self._power,
        "offset": self._offset,
        "pre_offset_learning_rate": self._pre_offset_lr,
        "name": self._name,
    }
    return config