lr_schedule.py 18.6 KB
Newer Older
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
1
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Abdullah Rashwan's avatar
Abdullah Rashwan committed
2
3
4
5
6
7
8
9
10
11
12
13
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Hongkun Yu's avatar
Hongkun Yu committed
14

Abdullah Rashwan's avatar
Abdullah Rashwan committed
15
16
"""Learning rate schedule classes."""

Yeqing Li's avatar
Yeqing Li committed
17
import math
Abdullah Rashwan's avatar
Abdullah Rashwan committed
18
19
20
21
22
from typing import Mapping, Any, Union, Optional

import tensorflow as tf


23
24
25
def _make_offset_wrapper(new_class_name: str, base_lr_class):
  """Generates an offset wrapper of a learning rate schedule.

  Returns a subclass of `base_lr_class` whose constructor takes an additional
  `offset` argument. When an instance of the new class is called, the behavior
  is:
    new_class_object(step) = base_lr_class_object(step - offset)

  The generated class also overrides `get_config` so that `offset` is part of
  the returned config; otherwise the offset would be lost when the schedule is
  rebuilt from its config.

  Example:
    CosineDecayWithOffset = _make_offset_wrapper(
                     'CosineDecayWithOffset', tf.keras.experimental.CosineDecay)
    # Use the lr:
    lr = CosineDecayWithOffset(offset=100, initial_learning_rate=0.1,
                               decay_steps=1000)
    lr(101) # equals to tf.keras.experimental.CosineDecay(...)(101-100)

  Args:
    new_class_name: the name of the new class.
    base_lr_class: the base learning rate schedule class. Should be subclass of
      tf.keras.optimizers.schedules.LearningRateSchedule

  Returns:
    A new class (subclass of the base_lr_class) that can take an offset.

  Raises:
    AssertionError: if `base_lr_class` is not a subclass of
      tf.keras.optimizers.schedules.LearningRateSchedule.
  """
  assert issubclass(base_lr_class,
                    tf.keras.optimizers.schedules.LearningRateSchedule), (
                        "base_lr_class should be subclass of keras "
                        f"LearningRateSchedule, got {base_lr_class}")

  # pylint: disable=protected-access,pointless-statement
  def offset_learning_rate_init(self, offset=0, **kwargs):
    """Construct learning rate schedule object.

    When this object is called, its behavior is
       self.__call__(step) == base_lr_class.__call__(step - offset)
    Args:
      self: this object.
      offset: The offset when computing the learning rate schedule.
      **kwargs: Pass through to base learning rate class constructor.
    """
    base_lr_class.__init__(self, **kwargs)
    self._offset = offset

  def offset_learning_rate_call(self, step):
    """Evaluates the base schedule at `step - offset`."""
    step = tf.cast(step - self._offset, tf.float32)
    return base_lr_class.__call__(self, step)

  def offset_learning_rate_get_config(self):
    """Returns the base schedule config extended with `offset`."""
    # Copy so that the dict returned by the base class is never mutated.
    config = dict(base_lr_class.get_config(self))
    config["offset"] = self._offset
    return config

  # pylint: enable=protected-access,pointless-statement

  return type(
      new_class_name, (base_lr_class,), {
          "base_lr_class": base_lr_class,
          "__init__": offset_learning_rate_init,
          "__call__": offset_learning_rate_call,
          "get_config": offset_learning_rate_get_config,
      })


# Offset-aware variants of the stock Keras schedules. Each class accepts an
# extra `offset` constructor argument and evaluates its base schedule at
# `step - offset`.
PiecewiseConstantDecayWithOffset = _make_offset_wrapper(
    "PiecewiseConstantDecayWithOffset",
    tf.keras.optimizers.schedules.PiecewiseConstantDecay)
PolynomialDecayWithOffset = _make_offset_wrapper(
    "PolynomialDecayWithOffset", tf.keras.optimizers.schedules.PolynomialDecay)
ExponentialDecayWithOffset = _make_offset_wrapper(
    "ExponentialDecayWithOffset",
    tf.keras.optimizers.schedules.ExponentialDecay)
# Use the `optimizers.schedules` path (same class as the former
# `tf.keras.experimental.CosineDecay` alias) to match the schedules above.
CosineDecayWithOffset = _make_offset_wrapper(
    "CosineDecayWithOffset", tf.keras.optimizers.schedules.CosineDecay)


Abdullah Rashwan's avatar
Abdullah Rashwan committed
92
93
94
class LinearWarmup(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Linear warmup schedule."""

  def __init__(self,
               after_warmup_lr_sched: Union[
                   tf.keras.optimizers.schedules.LearningRateSchedule, float],
               warmup_steps: int,
               warmup_learning_rate: float,
               name: Optional[str] = None):
    """Add linear warmup schedule to a learning rate schedule.

    warmup_lr is the initial learning rate, the final learning rate of the
    init_warmup period is the initial learning rate of lr_schedule in use.
    The learning rate at each step linearly increased according to the following
    formula:
      learning_rate = warmup_lr + step / warmup_steps
                    * (final_warmup_lr - warmup_lr).
    Using warmup overrides the learning rate schedule by the number of warmup
    steps.

    Args:
      after_warmup_lr_sched: tf.keras.optimizers.schedules.LearningRateSchedule
        or a constant.
      warmup_steps: Number of the warmup steps.
      warmup_learning_rate: Initial learning rate for the warmup.
      name: Optional, name of warmup schedule.
    """
    super().__init__()
    self._name = name
    self._after_warmup_lr_sched = after_warmup_lr_sched
    self._warmup_steps = warmup_steps
    self._init_warmup_lr = warmup_learning_rate
    if isinstance(after_warmup_lr_sched,
                  tf.keras.optimizers.schedules.LearningRateSchedule):
      # The warmup ramps towards the value the wrapped schedule produces at
      # the step where warmup ends.
      self._final_warmup_lr = after_warmup_lr_sched(warmup_steps)
    else:
      self._final_warmup_lr = tf.cast(after_warmup_lr_sched, dtype=tf.float32)

  def __call__(self, step: int):
    """Returns the learning rate at `step`.

    Args:
      step: The current training step.

    Returns:
      The linearly interpolated warmup learning rate while
      `step < warmup_steps`, otherwise the value of the wrapped schedule (or
      the constant) at `step`.
    """
    global_step = tf.cast(step, dtype=tf.float32)

    # `divide_no_nan` returns 0 where the divisor is 0, so a zero
    # `warmup_steps` no longer produces Inf/NaN in the warmup branch. For any
    # non-zero `warmup_steps` this is an ordinary float32 division.
    warmup_fraction = tf.math.divide_no_nan(
        global_step, tf.cast(self._warmup_steps, dtype=tf.float32))
    linear_warmup_lr = (
        self._init_warmup_lr + warmup_fraction *
        (self._final_warmup_lr - self._init_warmup_lr))

    if isinstance(self._after_warmup_lr_sched,
                  tf.keras.optimizers.schedules.LearningRateSchedule):
      after_warmup_lr = self._after_warmup_lr_sched(step)
    else:
      after_warmup_lr = tf.cast(self._after_warmup_lr_sched, dtype=tf.float32)

    lr = tf.cond(global_step < self._warmup_steps,
                 lambda: linear_warmup_lr,
                 lambda: after_warmup_lr)
    return lr

  def get_config(self) -> Mapping[str, Any]:
    """Returns the constructor arguments of this schedule as a mapping.

    When the wrapped object is a LearningRateSchedule, its own `get_config()`
    result is stored under "after_warmup_lr_sched"; otherwise the constant is
    stored as is.
    """
    if isinstance(self._after_warmup_lr_sched,
                  tf.keras.optimizers.schedules.LearningRateSchedule):
      config = {
          "after_warmup_lr_sched": self._after_warmup_lr_sched.get_config()}  # pytype: disable=attribute-error
    else:
      config = {"after_warmup_lr_sched": self._after_warmup_lr_sched}  # pytype: disable=attribute-error

    config.update({
        "warmup_steps": self._warmup_steps,
        "warmup_learning_rate": self._init_warmup_lr,
        "name": self._name
    })
    return config


class PolynomialWarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Polynomial warmup applied in front of a learning rate decay schedule."""

  def __init__(self,
               after_warmup_lr_sched: Union[
                   tf.keras.optimizers.schedules.LearningRateSchedule, float],
               warmup_steps: int,
               power: float = 1.0,
               name: str = "PolynomialWarmup"):
    """Initialize configuration of the warmup schedule.

    Args:
      after_warmup_lr_sched: tf.keras.optimizers.schedules.LearningRateSchedule
        or a constant, used once warmup is over.
      warmup_steps: Number of the warmup steps.
      power: The order of the polynomial.
      name: Optional, name of warmup schedule.
    """
    super().__init__()
    # Learning rate reached at the end of warmup: the wrapped schedule
    # evaluated at `warmup_steps`, or the constant itself.
    if not isinstance(after_warmup_lr_sched,
                      tf.keras.optimizers.schedules.LearningRateSchedule):
      self._initial_learning_rate = tf.cast(
          after_warmup_lr_sched, dtype=tf.float32)
    else:
      self._initial_learning_rate = after_warmup_lr_sched(warmup_steps)

    self._after_warmup_lr_sched = after_warmup_lr_sched
    self._warmup_steps = warmup_steps
    self._power = power
    self._name = name

  def __call__(self, step):
    """Returns the learning rate at `step`."""
    with tf.name_scope(self._name or "PolynomialWarmUp") as scope:
      # While `step < warmup_steps` the learning rate is
      # `(step / warmup_steps)^power * init_lr`.
      step_f = tf.cast(step, tf.float32)
      warmup_f = tf.cast(self._warmup_steps, tf.float32)

      # Without warmup the fraction is pinned to 1. Otherwise `step` is
      # clamped to at least 1, because a zero `step` may cause Inf.
      fraction_done = (
          1.0 if self._warmup_steps <= 0 else
          tf.math.maximum(step_f, 1.0) / warmup_f)

      lr_during_warmup = (
          self._initial_learning_rate *
          tf.math.pow(fraction_done, self._power))

      sched = self._after_warmup_lr_sched
      if isinstance(sched, tf.keras.optimizers.schedules.LearningRateSchedule):
        lr_after_warmup = sched(step)
      else:
        lr_after_warmup = tf.cast(sched, dtype=tf.float32)

      return tf.cond(
          step_f < warmup_f,
          lambda: lr_during_warmup,
          lambda: lr_after_warmup,
          name=scope)

  def get_config(self) -> Mapping[str, Any]:
    """Returns the constructor arguments of this schedule as a mapping."""
    after_warmup = self._after_warmup_lr_sched
    if isinstance(after_warmup,
                  tf.keras.optimizers.schedules.LearningRateSchedule):
      # A wrapped schedule is represented by its own config.
      after_warmup = after_warmup.get_config()  # pytype: disable=attribute-error

    return {
        "after_warmup_lr_sched": after_warmup,
        "warmup_steps": self._warmup_steps,
        "power": self._power,
        "name": self._name,
    }
231
232
233
234
235
236
237
238
239
240
241
242


class DirectPowerDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Schedule whose learning rate is `initial_learning_rate * step^power`."""

  def __init__(self,
               initial_learning_rate: float,
               power: float = 1.0,
               name: str = "DirectPowerDecay"):
    """Stores the parameters of the schedule.

    Args:
      initial_learning_rate: The initial learning rate.
      power: The order of the polynomial.
      name: Optional, name of learning rate schedule.
    """
    super().__init__()
    self._name = name
    self._power = power
    self._initial_learning_rate = initial_learning_rate

  def __call__(self, step):
    """Returns the learning rate at `step`."""
    with tf.name_scope(self._name or "DirectPowerDecay"):
      # Clamp to at least 1: a zero `step` may cause Inf.
      clamped_step = tf.math.maximum(tf.cast(step, tf.float32), 1.0)
      lr = self._initial_learning_rate
      lr *= tf.math.pow(clamped_step, self._power)
      return lr

  def get_config(self):
    """Returns the constructor arguments of this schedule as a dict."""
    config = {}
    config["initial_learning_rate"] = self._initial_learning_rate
    config["power"] = self._power
    config["name"] = self._name
    return config
Le Hou's avatar
Le Hou committed
268
269
270
271
272


class PowerAndLinearDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Power learning rate schedule multiplied by a linear decay at the end.

  The schedule has the following behavior.
  Let offset_step = step - offset.
  1) offset_step < 0, the actual learning rate equals initial_learning_rate.
  2) offset_step <= total_decay_steps * (1 - linear_decay_fraction), the
  actual learning rate equals lr * offset_step^power.
  3) total_decay_steps * (1 - linear_decay_fraction) <= offset_step <
  total_decay_steps, the actual learning rate equals lr * offset_step^power *
  (total_decay_steps - offset_step) / (total_decay_steps *
  linear_decay_fraction).
  4) offset_step >= total_decay_steps, the actual learning rate equals zero.
  """

  def __init__(self,
               initial_learning_rate: float,
               total_decay_steps: int,
               power: float = 1.0,
               linear_decay_fraction: float = 0.1,
               offset: int = 0,
               name: str = "PowerAndLinearDecay"):
    """Stores the parameters of the schedule.

    Args:
      initial_learning_rate: The initial learning rate.
      total_decay_steps: The total number of steps for power + linear decay.
      power: The order of the polynomial.
      linear_decay_fraction: In the last `linear_decay_fraction` steps, the
        learning rate will be multiplied by a linear decay.
      offset: The offset applied to steps.
      name: Optional, name of learning rate schedule.
    """
    super().__init__()
    self._name = name
    self._offset = offset
    self._linear_decay_fraction = linear_decay_fraction
    self._power = power
    self._total_decay_steps = total_decay_steps
    self._initial_learning_rate = initial_learning_rate

  def __call__(self, step):
    """Returns the learning rate at `step`."""
    with tf.name_scope(self._name or "PowerAndLinearDecay"):
      offset_step = tf.cast(step - self._offset, tf.float32)
      lr = self._initial_learning_rate
      # Clamp to at least 1: a zero step may cause Inf.
      lr *= tf.math.pow(tf.math.maximum(offset_step, 1.0), self._power)

      # Number of steps over which the linear ramp-down is applied.
      linear_span = self._total_decay_steps * self._linear_decay_fraction
      if linear_span > 0:
        remaining = self._total_decay_steps - offset_step
        lr *= tf.minimum(1.0, remaining / linear_span)
        # Past `total_decay_steps` the factor is negative; floor at zero.
        lr = tf.maximum(0.0, lr)
      return lr

  def get_config(self):
    """Returns the constructor arguments of this schedule as a dict."""
    config = {}
    config["initial_learning_rate"] = self._initial_learning_rate
    config["total_decay_steps"] = self._total_decay_steps
    config["power"] = self._power
    config["linear_decay_fraction"] = self._linear_decay_fraction
    config["offset"] = self._offset
    config["name"] = self._name
    return config
Le Hou's avatar
Le Hou committed
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356


class PowerDecayWithOffset(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Power learning rate decay that starts after an offset.

  The learning rate is `pre_offset_learning_rate` until `step` exceeds
  `offset`. After that it is
  initial_learning_rate * (step - offset)^power, capped at
  `pre_offset_learning_rate`.
  """

  def __init__(self,
               initial_learning_rate: float,
               power: float = 1.0,
               offset: int = 0,
               pre_offset_learning_rate: float = 1.0e6,
               name: str = "PowerDecayWithOffset"):
    """Stores the parameters of the schedule.

    Args:
      initial_learning_rate: The initial learning rate.
      power: The order of the polynomial.
      offset: The offset when computing the power decay.
      pre_offset_learning_rate: The maximum learning rate we'll use.
      name: Optional, name of learning rate schedule.
    """
    super().__init__()
    self._name = name
    self._pre_offset_lr = pre_offset_learning_rate
    self._offset = offset
    self._power = power
    self._initial_learning_rate = initial_learning_rate

  def __call__(self, step):
    """Returns the learning rate at `step`."""
    with tf.name_scope(self._name or "PowerDecayWithOffset"):
      step = tf.cast(step, tf.float32)
      # Steps elapsed since the offset, clamped to at least 1.
      steps_past_offset = tf.math.maximum(step - self._offset, 1.0)
      lr_after_offset = tf.math.pow(steps_past_offset, self._power) * (
          self._initial_learning_rate)

      # 1.0 once `step` is beyond the offset, 0.0 before; selects the branch
      # arithmetically.
      past_offset = tf.cast(step > self._offset, tf.float32)
      blended_lr = ((1.0 - past_offset) * self._pre_offset_lr +
                    past_offset * lr_after_offset)
      # Power may give infinitely large LR. So cap it with pre_offset_lr.
      return tf.math.minimum(blended_lr, self._pre_offset_lr)

  def get_config(self):
    """Returns the constructor arguments of this schedule as a dict."""
    config = {}
    config["initial_learning_rate"] = self._initial_learning_rate
    config["power"] = self._power
    config["offset"] = self._offset
    config["pre_offset_learning_rate"] = self._pre_offset_lr
    config["name"] = self._name
    return config
Yeqing Li's avatar
Yeqing Li committed
387
388


389
class StepCosineDecayWithOffset(
    tf.keras.optimizers.schedules.LearningRateSchedule):
  """Stepwise cosine learning rate decay with offset.

  Learning rate is equivalent to one or more cosine decay(s) starting and
  ending at each interval. `boundaries[i]` is the step at which level `i`
  starts; within level `i` the learning rate follows a cosine curve from
  `values[i]` towards `values[i + 1]` (towards 0.0 for the last level) over
  `boundaries[i + 1] - boundaries[i]` steps.

  Example:

    ```python
    boundaries = [0, 100000, 110000]
    values = [1.0, 0.5, 0.0]
    lr_decayed_fn = (
    lr_schedule.StepCosineDecayWithOffset(
        boundaries,
        values))
    ```

    from 0 to 100000 step, it will cosine decay from 1.0 to 0.5
    from 100000 to 110000 step, it cosine decay from 0.5 to 0.0

  NOTE(review): the first level is evaluated from step 0 rather than from
  `boundaries[0]`, so `boundaries[0]` is presumably expected to be 0. The last
  level is recorded with a length of 0 steps, so evaluating the schedule at or
  beyond the last boundary divides by zero and appears to yield a non-finite
  value -- confirm the intended behavior with the owners.
  """

  def __init__(self,
               boundaries,
               values,
               offset: int = 0,
               name: str = "StepCosineDecayWithOffset"):
    """Initialize configuration of the learning rate schedule.

    Args:
      boundaries: A list of `Tensor`s or `int`s with strictly
        increasing entries, and with all elements having the same type as the
        optimizer step. Entry `i` is the step at which level `i` starts.
      values: A list of `Tensor`s or `float`s that specifies the
        starting value of each level defined by `boundaries`. It must have
        the same number of elements as `boundaries`, and all elements should
        have the same type.
      offset: The offset subtracted from the step before the schedule is
        evaluated.
      name: Optional, name of learning rate schedule.

    Raises:
      ValueError: if `values` is empty, or if `boundaries` and `values` do not
        have the same length.
    """
    super().__init__()
    self.values = values
    self.boundaries = boundaries
    self.offset = offset
    self.name = name

    if len(self.values) < 1:
      raise ValueError(f"Expect non empty {self.values}")
    if len(self.boundaries) != len(self.values):
      raise ValueError(
          "Boundaries length must be equal to learning rate levels length: "
          f"{len(self.boundaries)} != {len(self.values)}")

    # Length in steps of every level; the last level has no following
    # boundary and is recorded with length 0.
    self.total_steps = [
        end - start for start, end in zip(boundaries, boundaries[1:])
    ] + [0]

  def __call__(self, global_step):
    """Returns the learning rate at `global_step` (after applying `offset`)."""
    with tf.name_scope(self.name or "StepCosineDecayWithOffset"):
      global_step = tf.cast(global_step - self.offset, tf.float32)
      lr_levels = self.values
      lr_steps = self.boundaries
      level_total_steps = self.total_steps
      num_levels = len(lr_levels)

      # First level: cosine curve from values[0] towards values[1] (or 0.0
      # when there is a single level), measured from step 0.
      init_lr = lr_levels[0]
      next_init_lr = lr_levels[1] if num_levels > 1 else 0.
      init_total_steps = level_total_steps[0]

      learning_rate = ((init_lr - next_init_lr) * (tf.cos(
          tf.constant(math.pi) * (global_step) /
          (init_total_steps)) + 1.0) / 2.0 + next_init_lr)

      # Every later level overrides the running result once the step has
      # reached that level's starting boundary.
      for i in range(1, num_levels):
        next_init_lr = lr_levels[i]
        next_start_step = lr_steps[i]
        next_total_steps = level_total_steps[i]
        next_next_init_lr = lr_levels[i + 1] if num_levels > i + 1 else 0.

        next_cosine_learning_rate = ((next_init_lr - next_next_init_lr) *
                                     (tf.cos(
                                         tf.constant(math.pi) *
                                         (global_step - next_start_step) /
                                         (next_total_steps)) + 1.0) / 2.0 +
                                     next_next_init_lr)
        learning_rate = tf.where(global_step >= next_start_step,
                                 next_cosine_learning_rate, learning_rate)

    return learning_rate

  def get_config(self):
    """Returns the constructor arguments of this schedule as a dict."""
    return {
        "boundaries": self.boundaries,
        "values": self.values,
        "offset": self.offset,
        "name": self.name
    }