# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Learning rate decay and weight decay incr functions."""

import math

from megatron import print_rank_0

class OptimizerParamScheduler(object):
    """Anneals learning rate and weight decay"""

    def __init__(self, optimizer, max_lr, min_lr,
                 lr_warmup_steps, lr_decay_steps, lr_decay_style,
                 start_wd, end_wd, wd_incr_steps, wd_incr_style,
                 use_checkpoint_opt_param_scheduler=True,
                 override_opt_param_scheduler=False):

        # Class values.
        self.optimizer = optimizer

        self.max_lr = float(max_lr)
        self.min_lr = min_lr
        assert self.min_lr >= 0.0
        assert self.max_lr >= self.min_lr

        self.lr_warmup_steps = lr_warmup_steps
        self.num_steps = 0
        self.lr_decay_steps = lr_decay_steps
        assert self.lr_decay_steps > 0
        assert self.lr_warmup_steps < self.lr_decay_steps

        self.lr_decay_style = lr_decay_style

        self.start_wd = start_wd
        self.end_wd = end_wd
        assert self.start_wd >= 0.0
        assert self.end_wd >= self.start_wd
        self.wd_incr_steps = wd_incr_steps
        self.wd_incr_style = wd_incr_style

        self.override_opt_param_scheduler = override_opt_param_scheduler
        self.use_checkpoint_opt_param_scheduler = use_checkpoint_opt_param_scheduler
        if self.override_opt_param_scheduler:
            assert not self.use_checkpoint_opt_param_scheduler, 'both override and '\
                'use-checkpoint are set.'

        # Set the initial learning rate and weight decay.
        self.step(0)
        print_rank_0('> learning rate decay style: {}'.format(self.lr_decay_style))


    def get_wd(self):
        """ Weight decay incr functions"""
        if self.num_steps > self.wd_incr_steps:
            return self.end_wd

        if self.wd_incr_style == 'constant':
            assert self.start_wd == self.end_wd
            return self.end_wd

        incr_ratio = float(self.num_steps) / float(self.wd_incr_steps)
        assert incr_ratio >= 0.0
        assert incr_ratio <= 1.0
        delta_wd = self.end_wd - self.start_wd

        if self.wd_incr_style == 'linear':
            coeff = incr_ratio
        elif self.wd_incr_style == 'cosine':
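            # Cosine ramp: coeff rises smoothly from 0 at incr_ratio=0 to 1 at
            # incr_ratio=1, so the weight decay anneals from start_wd to end_wd.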
            coeff = 0.5 * (math.cos(math.pi * (1 - incr_ratio)) + 1.0)
        else:
            raise Exception('{} weight decay increment style is not supported.'.format(
                self.wd_incr_style))

        return self.start_wd + coeff * delta_wd


    def get_lr(self):
        """Learning rate decay functions from:
              https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""
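        # Branch summary: linear warmup first; the 'constant' style then always
        # returns max_lr; past lr_decay_steps the lr floors at min_lr; otherwise
        # lr decays from max_lr toward min_lr according to lr_decay_style.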

        # Use linear warmup for the initial part.
        if self.lr_warmup_steps > 0 and self.num_steps <= self.lr_warmup_steps:
            return self.max_lr * float(self.num_steps) / \
                float(self.lr_warmup_steps)

        # If the learning rate is constant, just return the initial value.
        if self.lr_decay_style == 'constant':
            return self.max_lr

        # For any steps larger than `self.lr_decay_steps`, use `self.min_lr`.
        if self.num_steps > self.lr_decay_steps:
            return self.min_lr

        # If we are done with the warmup period, use the decay style.
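        # Inverse square root: lr = max_lr * sqrt(warmup_steps / num_steps),
        # which equals max_lr at the end of warmup and is floored at min_lr.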
        if self.lr_decay_style == 'inverse-square-root':
            warmup_steps = max(self.lr_warmup_steps, 1)
            num_steps = max(self.num_steps, 1)
            lr = self.max_lr * warmup_steps ** 0.5 / (num_steps ** 0.5)
            return max(self.min_lr, lr)

        num_steps_ = self.num_steps - self.lr_warmup_steps
        decay_steps_ = self.lr_decay_steps - self.lr_warmup_steps
        decay_ratio = float(num_steps_) / float(decay_steps_)
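        # decay_ratio runs from 0 (end of warmup) to 1 (at lr_decay_steps).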
        assert decay_ratio >= 0.0
        assert decay_ratio <= 1.0
        delta_lr = self.max_lr - self.min_lr

        if self.lr_decay_style == 'linear':
            coeff = (1.0 - decay_ratio)
        elif self.lr_decay_style == 'cosine':
            coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
        else:
            raise Exception('{} decay style is not supported.'.format(
                self.lr_decay_style))

        return self.min_lr + coeff * delta_lr


    def step(self, increment):
        """Set lr for all parameters groups."""
        self.num_steps += increment
        new_lr = self.get_lr()
        new_wd = self.get_wd()
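        # Param groups may carry optional per-group multipliers
        # ('lr_mult' / 'wd_mult'); groups without them default to 1.0.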
        for group in self.optimizer.param_groups:
            group['lr'] = new_lr * group.get('lr_mult', 1.0)
            group['weight_decay'] = new_wd * group.get('wd_mult', 1.0)


    def state_dict(self):
        state_dict = {
            'max_lr': self.max_lr,
            'lr_warmup_steps': self.lr_warmup_steps,
            'num_steps': self.num_steps,
            'lr_decay_style': self.lr_decay_style,
            'lr_decay_steps': self.lr_decay_steps,
            'min_lr': self.min_lr,
            'start_wd': self.start_wd,
            'end_wd': self.end_wd,
            'wd_incr_style': self.wd_incr_style,
            'wd_incr_steps': self.wd_incr_steps
        }
        return state_dict


    def _check_and_set(self, cls_value, sd_value, name):
        """Auxiliary function for checking the values in the checkpoint and
        setting them."""
        if self.override_opt_param_scheduler:
            print_rank_0(' > overriding {} value to {}'.format(name, cls_value))
            return cls_value

        if not self.use_checkpoint_opt_param_scheduler:
            assert cls_value == sd_value, \
                f'OptimizerParamScheduler: class input value {cls_value} and checkpoint ' \
                f'value {sd_value} for {name} do not match'
        print_rank_0(' > using checkpoint value {} for {}'.format(sd_value,
                                                                  name))
        return sd_value


    def load_state_dict(self, sd):

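        # Backward compatibility: older checkpoints stored these values under
        # different key names ('start_lr', 'warmup_iter', 'end_iter',
        # 'decay_style', 'num_iters'), so check the old names first.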
        if 'start_lr' in sd:
            max_lr_ = sd['start_lr']
        else:
            max_lr_ = sd['max_lr']
        self.max_lr = self._check_and_set(self.max_lr, max_lr_,
                                          'learning rate')
        
        self.min_lr = self._check_and_set(self.min_lr, sd['min_lr'],
                                          'minimum learning rate')

        if 'warmup_iter' in sd:
            lr_warmup_steps_ = sd['warmup_iter']
        elif 'warmup_steps' in sd:
            lr_warmup_steps_ = sd['warmup_steps']
        else:
            lr_warmup_steps_ = sd['lr_warmup_steps']
        self.lr_warmup_steps = self._check_and_set(self.lr_warmup_steps,
                                                   lr_warmup_steps_,
                                                   'warmup iterations')

        if 'end_iter' in sd:
            lr_decay_steps_ = sd['end_iter']
        elif 'decay_steps' in sd:
            lr_decay_steps_ = sd['decay_steps']
        else:
            lr_decay_steps_ = sd['lr_decay_steps']
        self.lr_decay_steps = self._check_and_set(self.lr_decay_steps,
                                                  lr_decay_steps_,
                                                  'total number of iterations')

        if 'decay_style' in sd:
            lr_decay_style_ = sd['decay_style']
        else:
            lr_decay_style_ = sd['lr_decay_style']
        self.lr_decay_style = self._check_and_set(self.lr_decay_style,
                                                  lr_decay_style_,
                                                  'learning rate decay style')

        if 'num_iters' in sd:
            num_steps = sd['num_iters']
        else:
            num_steps = sd['num_steps']
        self.step(increment=num_steps)


        if 'start_wd' in sd:
            self.start_wd = self._check_and_set(self.start_wd,
                                                sd['start_wd'],
                                                "start weight decay")
            self.end_wd = self._check_and_set(self.end_wd,
                                              sd['end_wd'],
                                              "end weight decay")
            self.wd_incr_steps = self._check_and_set(self.wd_incr_steps,
                                                     sd['wd_incr_steps'],
                                                     "total number of weight decay iterations")
            self.wd_incr_style = self._check_and_set(self.wd_incr_style,
                                                     sd['wd_incr_style'],
                                                     "weight decay incr style")