# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Learning rate schedule."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf

from official.modeling.hyperparams import params_dict


class StepLearningRateWithLinearWarmup(
    tf.keras.optimizers.schedules.LearningRateSchedule):
  """Step learning rate schedule with linear warmup."""

  def __init__(self, total_steps, params):
    """Creates the step learning rate tensor with linear warmup.

    Args:
      total_steps: the total number of training steps.
      params: a dict or `params_dict.ParamsDict` containing
        `warmup_learning_rate`, `warmup_steps`, `init_learning_rate`,
        `learning_rate_levels` and `learning_rate_steps`.
    """
    super(StepLearningRateWithLinearWarmup, self).__init__()
    self._total_steps = total_steps
    # Accept either a plain dict or a ParamsDict and normalize to ParamsDict.
    assert isinstance(params, (dict, params_dict.ParamsDict))
    if isinstance(params, dict):
      params = params_dict.ParamsDict(params)
    self._params = params

  def __call__(self, global_step):
    warmup_lr = self._params.warmup_learning_rate
    warmup_steps = self._params.warmup_steps
    init_lr = self._params.init_learning_rate
    lr_levels = self._params.learning_rate_levels
    lr_steps = self._params.learning_rate_steps
    # Ramp linearly from warmup_lr to init_lr over the first warmup_steps.
    linear_warmup = (
        warmup_lr + tf.cast(global_step, dtype=tf.float32) / warmup_steps *
        (init_lr - warmup_lr))
    learning_rate = tf.where(global_step < warmup_steps, linear_warmup, init_lr)

    # After warmup, apply a piecewise-constant decay: switch to each level in
    # lr_levels once global_step reaches the corresponding entry in lr_steps.
    for next_learning_rate, start_step in zip(lr_levels, lr_steps):
      learning_rate = tf.where(global_step >= start_step, next_learning_rate,
                               learning_rate)
    return learning_rate

  def get_config(self):
    return {'_params': self._params.as_dict()}
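
# Illustrative behavior of StepLearningRateWithLinearWarmup (the values below
# are assumptions for the example, not defaults from any released config):
# with warmup_learning_rate=0.0067, warmup_steps=500, init_learning_rate=0.08,
# learning_rate_levels=[0.008, 0.0008] and learning_rate_steps=[15000, 20000],
# the rate ramps linearly from 0.0067 to 0.08 over the first 500 steps, holds
# 0.08 until step 15000, drops to 0.008, and drops again to 0.0008 at 20000.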


class CosineLearningRateWithLinearWarmup(
    tf.keras.optimizers.schedules.LearningRateSchedule):
  """Cosine learning rate schedule with linear warmup."""

  def __init__(self, total_steps, params):
    """Creates the cosine learning rate tensor with linear warmup.

    Args:
      total_steps: the total number of training steps.
      params: a dict or `params_dict.ParamsDict` containing
        `warmup_learning_rate`, `warmup_steps` and `init_learning_rate`.
    """
    super(CosineLearningRateWithLinearWarmup, self).__init__()
    self._total_steps = total_steps
    assert isinstance(params, (dict, params_dict.ParamsDict))
    if isinstance(params, dict):
      params = params_dict.ParamsDict(params)
    self._params = params

  def __call__(self, global_step):
    global_step = tf.cast(global_step, dtype=tf.float32)
    warmup_lr = self._params.warmup_learning_rate
    warmup_steps = self._params.warmup_steps
    init_lr = self._params.init_learning_rate
    total_steps = self._total_steps
    # Ramp linearly from warmup_lr to init_lr over the first warmup_steps.
    linear_warmup = (
        warmup_lr + global_step / warmup_steps * (init_lr - warmup_lr))
    # After warmup, decay from init_lr to 0 along a half cosine wave over the
    # remaining (total_steps - warmup_steps) steps.
    cosine_learning_rate = (
        init_lr * (tf.cos(np.pi * (global_step - warmup_steps) /
                          (total_steps - warmup_steps)) + 1.0) / 2.0)
    learning_rate = tf.where(global_step < warmup_steps, linear_warmup,
                             cosine_learning_rate)
    return learning_rate

  def get_config(self):
    return {'_params': self._params.as_dict()}
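
# The cosine phase above follows, for step t, warmup length w and total steps
# T:
#     lr(t) = init_lr * (cos(pi * (t - w) / (T - w)) + 1) / 2
# so the rate equals init_lr at t == w and decays smoothly to 0 at t == T.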


def learning_rate_generator(total_steps, params):
  """Creates a learning rate schedule from the given config params."""
  if params.type == 'step':
    return StepLearningRateWithLinearWarmup(total_steps, params)
  elif params.type == 'cosine':
    return CosineLearningRateWithLinearWarmup(total_steps, params)
  else:
    raise ValueError('Unsupported learning rate type: {}.'.format(params.type))
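

# A minimal usage sketch (illustrative): build a schedule from a ParamsDict
# and sample it at a few steps. The hyperparameter values below are
# assumptions, not defaults from any released config. A schedule like this
# can also be passed directly as the `learning_rate` of a tf.keras optimizer.
if __name__ == '__main__':
  lr_config = params_dict.ParamsDict({
      'type': 'cosine',
      'warmup_learning_rate': 0.0067,
      'warmup_steps': 500,
      'init_learning_rate': 0.08,
  })
  schedule = learning_rate_generator(total_steps=22500, params=lr_config)
  for step in (0, 250, 500, 11500, 22500):
    print(step, float(schedule(tf.constant(step))))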