# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Learning rate decay and weight decay incr functions."""

import math

from megatron import print_rank_0

class OptimizerParamScheduler(object):
    """Anneals learning rate and weight decay"""

    def __init__(self, optimizer, init_lr, max_lr, min_lr,
                 lr_warmup_steps, lr_decay_steps, lr_decay_style,
                 start_wd, end_wd, wd_incr_steps, wd_incr_style,
                 use_checkpoint_opt_param_scheduler=True,
                 override_opt_param_scheduler=False):

        # Class values.
        self.optimizer = optimizer

        self.init_lr = init_lr
        self.max_lr = float(max_lr)
        self.min_lr = min_lr
        assert self.min_lr >= 0.0
        assert self.max_lr >= self.min_lr
        assert self.init_lr <= self.max_lr

        self.lr_warmup_steps = lr_warmup_steps
        self.num_steps = 0
        self.lr_decay_steps = lr_decay_steps
        assert self.lr_decay_steps > 0
        assert self.lr_warmup_steps < self.lr_decay_steps

        self.lr_decay_style = lr_decay_style

        self.start_wd = start_wd
        self.end_wd = end_wd
        assert self.start_wd >= 0.0
        assert self.end_wd >= self.start_wd
        self.wd_incr_steps = wd_incr_steps
        self.wd_incr_style = wd_incr_style

        self.override_opt_param_scheduler = override_opt_param_scheduler
        self.use_checkpoint_opt_param_scheduler = use_checkpoint_opt_param_scheduler
        if self.override_opt_param_scheduler:
            assert not self.use_checkpoint_opt_param_scheduler, 'both override and '\
                'use-checkpoint are set.'

        # Set the learning rate
        self.step(0)
        print_rank_0('> learning rate decay style: {}'.format(self.lr_decay_style))


    def get_wd(self):
        """ Weight decay incr functions"""
        if self.num_steps > self.wd_incr_steps:
            return self.end_wd

        if self.wd_incr_style == 'constant':
            assert self.start_wd == self.end_wd
            return self.end_wd

        incr_ratio = float(self.num_steps) / float(self.wd_incr_steps)
        assert incr_ratio >= 0.0
        assert incr_ratio <= 1.0
        delta_wd = self.end_wd - self.start_wd

        if self.wd_incr_style == 'linear':
            coeff = incr_ratio
        elif self.wd_incr_style == 'cosine':
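            # Half-period cosine ramp: coeff rises smoothly from 0 at
            # incr_ratio=0 to 1 at incr_ratio=1, so wd anneals from
            # start_wd up to end_wd.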
            coeff = 0.5 * (math.cos(math.pi * (1 - incr_ratio)) + 1.0)
        else:
            raise Exception('{} weight decay increment style is not supported.'.format(
                self.wd_incr_style))

        return self.start_wd + coeff * delta_wd


    def get_lr(self):
        """Learning rate decay functions from:
              https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""

        # Use linear warmup for the initial part.
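        # lr ramps linearly from init_lr at step 0 to max_lr at
        # step lr_warmup_steps.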
        if self.lr_warmup_steps > 0 and self.num_steps <= self.lr_warmup_steps:
            return (
                self.init_lr
                + (
                    (self.max_lr - self.init_lr)
                    * float(self.num_steps)
                    / float(self.lr_warmup_steps)
                )
            )

        # If the learning rate is constant, just return the initial value.
        if self.lr_decay_style == 'constant':
            return self.max_lr

        # For any steps larger than `self.lr_decay_steps`, use `self.min_lr`.
        if self.num_steps > self.lr_decay_steps:
            return self.min_lr

        # If we are done with the warmup period, use the decay style.
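        # Inverse square-root decay: lr falls off as 1 / sqrt(num_steps),
        # scaled so that it matches max_lr at the last warmup step.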
        if self.lr_decay_style == 'inverse-square-root':
            warmup_steps = max(self.lr_warmup_steps, 1)
            num_steps = max(self.num_steps, 1)
            lr = self.max_lr * warmup_steps ** 0.5 / (num_steps ** 0.5)
            return max(self.min_lr, lr)

        num_steps_ = self.num_steps - self.lr_warmup_steps
        decay_steps_ = self.lr_decay_steps - self.lr_warmup_steps
        decay_ratio = float(num_steps_) / float(decay_steps_)
        assert decay_ratio >= 0.0
        assert decay_ratio <= 1.0
        delta_lr = self.max_lr - self.min_lr

        if self.lr_decay_style == 'linear':
            coeff = (1.0 - decay_ratio)
        elif self.lr_decay_style == 'cosine':
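            # Half-period cosine anneal: coeff falls smoothly from 1 at the
            # start of decay to 0 at lr_decay_steps.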
            coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
        else:
            raise Exception('{} decay style is not supported.'.format(
                self.lr_decay_style))

        return self.min_lr + coeff * delta_lr


    def step(self, increment):
        """Set lr for all parameters groups."""
        self.num_steps += increment
        new_lr = self.get_lr()
        new_wd = self.get_wd()
        for group in self.optimizer.param_groups:
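            # Per-group scaling: a param group may carry optional 'lr_mult'
            # and 'wd_mult' multipliers (defaulting to 1.0), e.g. to scale
            # down updates for particular parameter subsets.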
            group['lr'] = new_lr * group.get('lr_mult', 1.0)
            group['weight_decay'] = new_wd * group.get('wd_mult', 1.0)


    def state_dict(self):
        state_dict = {
            'max_lr': self.max_lr,
            'lr_warmup_steps': self.lr_warmup_steps,
            'num_steps': self.num_steps,
            'lr_decay_style': self.lr_decay_style,
            'lr_decay_steps': self.lr_decay_steps,
            'min_lr': self.min_lr,
            'start_wd': self.start_wd,
            'end_wd': self.end_wd,
            'wd_incr_style': self.wd_incr_style,
            'wd_incr_steps': self.wd_incr_steps
        }
        return state_dict


    def _check_and_set(self, cls_value, sd_value, name):
        """Auxiliary function for checking the values in the checkpoint and
        setting them."""
        if self.override_opt_param_scheduler:
            print_rank_0(' > overriding {} value to {}'.format(name, cls_value))
            return cls_value

        if not self.use_checkpoint_opt_param_scheduler:
            assert cls_value == sd_value, \
                f'OptimizerParamScheduler: class input value {cls_value} and checkpoint ' \
                f'value {sd_value} for {name} do not match'
        print_rank_0(' > using checkpoint value {} for {}'.format(sd_value,
                                                                  name))
        return sd_value


    def load_state_dict(self, sd):

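        # Older checkpoints stored these values under different key names
        # (e.g. 'start_lr', 'warmup_iter', 'end_iter'); fall back through
        # the legacy names for backward compatibility.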
        if 'start_lr' in sd:
            max_lr_ = sd['start_lr']
        else:
            max_lr_ = sd['max_lr']
        self.max_lr = self._check_and_set(self.max_lr, max_lr_,
                                          'learning rate')
        
        self.min_lr = self._check_and_set(self.min_lr, sd['min_lr'],
                                          'minimum learning rate')

        if 'warmup_iter' in sd:
            lr_warmup_steps_ = sd['warmup_iter']
        elif 'warmup_steps' in sd:
            lr_warmup_steps_ = sd['warmup_steps']
        else:
            lr_warmup_steps_ = sd['lr_warmup_steps']
        self.lr_warmup_steps = self._check_and_set(self.lr_warmup_steps,
                                                   lr_warmup_steps_,
                                                   'warmup iterations')

        if 'end_iter' in sd:
            lr_decay_steps_ = sd['end_iter']
        elif 'decay_steps' in sd:
            lr_decay_steps_ = sd['decay_steps']
        else:
            lr_decay_steps_ = sd['lr_decay_steps']
        self.lr_decay_steps = self._check_and_set(self.lr_decay_steps,
                                                  lr_decay_steps_,
                                                  'total number of iterations')

        if 'decay_style' in sd:
            lr_decay_style_ = sd['decay_style']
        else:
            lr_decay_style_ = sd['lr_decay_style']
        self.lr_decay_style = self._check_and_set(self.lr_decay_style,
                                                  lr_decay_style_,
                                                  'learning rate decay style')

        if 'num_iters' in sd:
            num_steps = sd['num_iters']
        else:
            num_steps = sd['num_steps']
        self.step(increment=num_steps)


        if 'start_wd' in sd:
            self.start_wd = self._check_and_set(self.start_wd,
                                                sd['start_wd'],
                                                "start weight decay")
            self.end_wd = self._check_and_set(self.end_wd,
                                              sd['end_wd'],
                                              "end weight decay")
            self.wd_incr_steps = self._check_and_set(self.wd_incr_steps,
                                                     sd['wd_incr_steps'],
                                                     "total number of weight decay iterations")
            self.wd_incr_style = self._check_and_set(self.wd_incr_style,
                                                     sd['wd_incr_style'],
                                                     "weight decay incr style")