# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Learning rate decay and weight decay incr functions."""

import math

from megatron import print_rank_0

class OptimizerParamScheduler(object):
    """Anneals learning rate and weight decay"""

    def __init__(self, optimizer, max_lr, min_lr,
                 lr_warmup_steps, lr_decay_steps, lr_decay_style,
                 start_wd, end_wd, wd_incr_steps, wd_incr_style,
                 use_checkpoint_opt_param_scheduler=True,
                 override_opt_param_scheduler=False):

        # Class values.
        self.optimizer = optimizer

        self.max_lr = float(max_lr)
        self.min_lr = min_lr
        assert self.min_lr >= 0.0
        assert self.max_lr >= self.min_lr

        self.lr_warmup_steps = lr_warmup_steps
        self.num_steps = 0
        self.lr_decay_steps = lr_decay_steps
        assert self.lr_decay_steps > 0
        assert self.lr_warmup_steps < self.lr_decay_steps

        self.lr_decay_style = lr_decay_style

        self.start_wd = start_wd
        self.end_wd = end_wd
        assert self.start_wd >= 0.0
        assert self.end_wd >= self.start_wd
        self.wd_incr_steps = wd_incr_steps
        self.wd_incr_style = wd_incr_style

        self.override_opt_param_scheduler = override_opt_param_scheduler
        self.use_checkpoint_opt_param_scheduler = use_checkpoint_opt_param_scheduler
        if self.override_opt_param_scheduler:
            assert not self.use_checkpoint_opt_param_scheduler, 'both override and '\
                'use-checkpoint are set.'

        # Set the initial learning rate and weight decay.
        self.step(0)
        print_rank_0('> learning rate decay style: {}'.format(self.lr_decay_style))


    def get_wd(self):
        """ Weight decay incr functions"""
        if self.num_steps > self.wd_incr_steps:
            return self.end_wd

        if self.wd_incr_style == 'constant':
            assert self.start_wd == self.end_wd
            return self.end_wd

        incr_ratio = float(self.num_steps) / float(self.wd_incr_steps)
        assert incr_ratio >= 0.0
        assert incr_ratio <= 1.0
        delta_wd = self.end_wd - self.start_wd

        if self.wd_incr_style == 'linear':
            coeff = incr_ratio
        elif self.wd_incr_style == 'cosine':
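            # Half-cosine ramp: coeff rises smoothly from 0 at incr_ratio=0
            # to 1 at incr_ratio=1, so wd goes from start_wd to end_wd.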
            coeff = 0.5 * (math.cos(math.pi * (1 - incr_ratio)) + 1.0)
        else:
            raise Exception('{} weight decay increment style is not supported.'.format(
                self.wd_incr_style))

        return self.start_wd + coeff * delta_wd


    def get_lr(self):
        """Learning rate decay functions from:
              https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""

        # Use linear warmup for the initial part.
        if self.lr_warmup_steps > 0 and self.num_steps <= self.lr_warmup_steps:
            return self.max_lr * float(self.num_steps) / \
                float(self.lr_warmup_steps)

        # If the learning rate is constant, just return the initial value.
        if self.lr_decay_style == 'constant':
            return self.max_lr

        # For any steps larger than `self.lr_decay_steps`, use `self.min_lr`.
        if self.num_steps > self.lr_decay_steps:
            return self.min_lr
        
        # If we are done with the warmup period, use the decay style.
        num_steps_ = self.num_steps - self.lr_warmup_steps
        decay_steps_ = self.lr_decay_steps - self.lr_warmup_steps
        decay_ratio = float(num_steps_) / float(decay_steps_)
        assert decay_ratio >= 0.0
        assert decay_ratio <= 1.0
        delta_lr = self.max_lr - self.min_lr

        if self.lr_decay_style == 'linear':
            coeff = (1.0 - decay_ratio)
        elif self.lr_decay_style == 'cosine':
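            # Cosine annealing: coeff falls smoothly from 1 at decay_ratio=0
            # to 0 at decay_ratio=1, so lr goes from max_lr down to min_lr.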
            coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
        else:
            raise Exception('{} decay style is not supported.'.format(
                self.lr_decay_style))

        return self.min_lr + coeff * delta_lr


    def step(self, increment):
        """Set lr for all parameters groups."""
        self.num_steps += increment
        new_lr = self.get_lr()
        new_wd = self.get_wd()
        for group in self.optimizer.param_groups:
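            # Per-group multipliers ('lr_mult', 'wd_mult') default to 1.0 when absent.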
            group['lr'] = new_lr * group.get('lr_mult', 1.0)
            group['weight_decay'] = new_wd * group.get('wd_mult', 1.0)


    def state_dict(self):
        state_dict = {
            'max_lr': self.max_lr,
            'lr_warmup_steps': self.lr_warmup_steps,
            'num_steps': self.num_steps,
            'lr_decay_style': self.lr_decay_style,
            'lr_decay_steps': self.lr_decay_steps,
            'min_lr': self.min_lr,
            'start_wd': self.start_wd,
            'end_wd': self.end_wd,
            'wd_incr_style': self.wd_incr_style,
            'wd_incr_steps': self.wd_incr_steps
        }
        return state_dict


    def _check_and_set(self, cls_value, sd_value, name):
        """Auxiliary function for checking the values in the checkpoint and
        setting them."""
        if self.override_opt_param_scheduler:
            print_rank_0(' > overriding {} value to {}'.format(name, cls_value))
            return cls_value

        if not self.use_checkpoint_opt_param_scheduler:
            assert cls_value == sd_value, \
                f'OptimizerParamScheduler: class input value {cls_value} and checkpoint ' \
                f'value {sd_value} for {name} do not match'
        print_rank_0(' > using checkpoint value {} for {}'.format(sd_value,
                                                                  name))
        return sd_value


    def load_state_dict(self, sd):
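        """Load scheduler state, accepting key names used by older checkpoints."""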

        if 'start_lr' in sd:
            max_lr_ = sd['start_lr']
        else:
            max_lr_ = sd['max_lr']
        self.max_lr = self._check_and_set(self.max_lr, max_lr_,
                                          'learning rate')
        
        self.min_lr = self._check_and_set(self.min_lr, sd['min_lr'],
                                          'minimum learning rate')

        if 'warmup_iter' in sd:
            lr_warmup_steps_ = sd['warmup_iter']
        elif 'warmup_steps' in sd:
            lr_warmup_steps_ = sd['warmup_steps']
        else:
            lr_warmup_steps_ = sd['lr_warmup_steps']
        self.lr_warmup_steps = self._check_and_set(self.lr_warmup_steps,
                                                   lr_warmup_steps_,
                                                   'warmup iterations')

        if 'end_iter' in sd:
            lr_decay_steps_ = sd['end_iter']
        elif 'decay_steps' in sd:
            lr_decay_steps_ = sd['decay_steps']
        else:
            lr_decay_steps_ = sd['lr_decay_steps']
        self.lr_decay_steps = self._check_and_set(self.lr_decay_steps, lr_decay_steps_,
                                                  'total number of iterations')

        if 'decay_style' in sd:
            lr_decay_style_ = sd['decay_style']
        else:
            lr_decay_style_ = sd['lr_decay_style']
        self.lr_decay_style = self._check_and_set(self.lr_decay_style,
                                                  lr_decay_style_,
                                                  'learning rate decay style')

        if 'num_iters' in sd:
            num_steps = sd['num_iters']
        else:
            num_steps = sd['num_steps']
        self.step(increment=num_steps)


        if 'start_wd' in sd:
            self.start_wd = self._check_and_set(self.start_wd,
                                                sd['start_wd'],
                                                'start weight decay')
            self.end_wd = self._check_and_set(self.end_wd,
                                              sd['end_wd'],
                                              'end weight decay')
            self.wd_incr_steps = self._check_and_set(self.wd_incr_steps,
                                                     sd['wd_incr_steps'],
                                                     'total number of weight decay iterations')
            self.wd_incr_style = self._check_and_set(self.wd_incr_style,
                                                     sd['wd_incr_style'],
                                                     'weight decay incr style')