# schedules.py
# Copyright 2021, Maxime Burchi.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

class constant_learning_rate_scheduler:

    def __init__(self, optimizer, lr_value):

        # Model Optimizer
        self.optimizer = optimizer

        # Model Step
        self.model_step = -1

        # Scheduler Params
        self.lr_value = lr_value

    def step(self):

        # Update Model Step
        self.model_step += 1

        # Update LR (constant value)
        self.optimizer.param_groups[0]['lr'] = self.lr_value

class constant_with_decay_learning_rate_scheduler:

    def __init__(self, optimizer, lr_values, decay_steps):

        # Model Optimizer
        self.optimizer = optimizer

        # Model Step
        self.model_step = -1

        # Scheduler Params
        self.lr_values = lr_values
        self.decay_steps = decay_steps

    def step(self):

        # Update Model Step
        self.model_step += 1

        # Select LR: lr_values[i] applies until decay_steps[i] is passed
        lr_value = self.lr_values[0]
        for i, step in enumerate(self.decay_steps):
            if self.model_step > step:
                lr_value = self.lr_values[i + 1]
            else:
                break

        # Update LR
        self.optimizer.param_groups[0]['lr'] = lr_value
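
# Illustration (hypothetical values, not from the original code): with
# lr_values=[1e-3, 1e-4, 1e-5] and decay_steps=[10000, 20000], the scheduler
# keeps lr at 1e-3 up to and including step 10000, switches to 1e-4 until
# step 20000, then stays at 1e-5 for the remaining steps.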

class cosine_annealing_learning_rate_scheduler:

    def __init__(self, optimizer, warmup_steps, lr_max, lr_min, end_step):

        # Model Optimizer
        self.optimizer = optimizer

        # Model Step
        self.model_step = -1

        # Scheduler Params
        self.warmup_steps = warmup_steps
        self.lr_max = lr_max
        self.lr_min = lr_min
        self.end_step = end_step

    def step(self):

        # Update Model Step
        self.model_step += 1
        s = self.model_step + 1

        # Compute LR
        if self.model_step < self.warmup_steps: # Warmup phase: linear increase, reaching lr_max at the last warmup step
            lr = s / self.warmup_steps * self.lr_max
        else: # Annealing phase: cosine decay from lr_max down to lr_min at end_step
            progress = (self.model_step - self.warmup_steps) / (self.end_step - self.warmup_steps)
            lr = self.lr_min + (self.lr_max - self.lr_min) * 0.5 * (1 + math.cos(math.pi * progress))

        # Update LR
        self.optimizer.param_groups[0]['lr'] = lr
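
# Worked example (hypothetical values, not from the original code): with
# warmup_steps=10000, lr_max=1e-3, lr_min=1e-5 and end_step=100000, the lr
# ramps linearly to 1e-3 over the first 10000 steps, is roughly 5.05e-4
# halfway through annealing (model_step=55000, where the cosine term is 0),
# and reaches 1e-5 at model_step=100000.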

class transformer_learning_rate_scheduler:

    def __init__(self, optimizer, dim_model, warmup_steps, K):

        # Model Optimizer
        self.optimizer = optimizer

        # Model Step
        self.model_step = -1

        # Scheduler Params
        self.dim_model = dim_model
        self.warmup_steps = warmup_steps
        self.K = K

    def step(self):
        
        # Update Model Step
        self.model_step += 1
        s = self.model_step + 1

        # Update LR
        arg1 = s**-0.5
        arg2 = s * (self.warmup_steps**-1.5)
        self.optimizer.param_groups[0]['lr'] = self.K * self.dim_model**-0.5 * min(arg1, arg2)
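
# Note: this is the Transformer ("Noam") schedule,
#   lr(s) = K * dim_model**-0.5 * min(s**-0.5, s * warmup_steps**-1.5),
# which increases linearly over the first warmup_steps steps and then decays
# proportionally to s**-0.5; its peak value, reached at s = warmup_steps, is
# K / math.sqrt(dim_model * warmup_steps).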

class exponential_decay_transformer_learning_rate_scheduler:

    def __init__(self, optimizer, warmup_steps, lr_max, alpha, end_step):

        # Model Optimizer
        self.optimizer = optimizer

        # Model Step
        self.model_step = -1

        # Scheduler Params
        self.warmup_steps = warmup_steps
        self.lr_max = lr_max
        self.alpha = alpha
        self.end_step = end_step

    def step(self):

        # Update Model Step
        self.model_step += 1
        s = self.model_step + 1

        # Update LR: the min() selects the linear warmup ramp while it is still below lr_max,
        # then (for 0 < alpha < 1) the exponential decay from lr_max down to lr_max * alpha at end_step
        arg1 = s / self.warmup_steps * self.lr_max # Warmup phase
        arg2 = self.lr_max * self.alpha**((s - self.warmup_steps) / (self.end_step - self.warmup_steps)) # Decay phase
        self.optimizer.param_groups[0]['lr'] = min(arg1, arg2)
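
# Minimal usage sketch (not part of the original file). It assumes a PyTorch
# optimizer, since the schedulers above write to optimizer.param_groups[0]['lr'];
# the model, optimizer and hyperparameter values below are purely illustrative.
if __name__ == "__main__":

    import torch

    # Toy model and optimizer (hypothetical)
    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0)

    # Transformer schedule with illustrative parameters
    scheduler = transformer_learning_rate_scheduler(optimizer, dim_model=256, warmup_steps=10000, K=2)

    # Training loop skeleton: call scheduler.step() once per update to set the lr
    for _ in range(5):
        scheduler.step()
        # ... forward pass, loss.backward() and optimizer.step() would go here ...
        print(scheduler.model_step, optimizer.param_groups[0]['lr'])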