# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Learning rate decay and weight decay incr functions."""

import math

from megatron import print_rank_0

class OptimizerParamScheduler(object):
    """Anneals learning rate and weight decay"""

    def __init__(self, optimizer, max_lr, min_lr,
                 lr_warmup_steps, lr_decay_steps, lr_decay_style,
                 start_wd, end_wd, wd_incr_steps, wd_incr_style,
                 use_checkpoint_opt_param_scheduler=True,
                 override_opt_param_scheduler=False):

        # Class values.
        self.optimizer = optimizer

        self.max_lr = float(max_lr)
        self.min_lr = min_lr
        assert self.min_lr >= 0.0
        assert self.max_lr >= self.min_lr

        self.lr_warmup_steps = lr_warmup_steps
        self.num_steps = 0
        self.lr_decay_steps = lr_decay_steps
        assert self.lr_decay_steps > 0
        assert self.lr_warmup_steps < self.lr_decay_steps

        self.lr_decay_style = lr_decay_style

        self.start_wd = start_wd
        self.end_wd = end_wd
        assert self.start_wd >= 0.0
        assert self.end_wd >= self.start_wd
        self.wd_incr_steps = wd_incr_steps
        self.wd_incr_style = wd_incr_style

        self.override_opt_param_scheduler = override_opt_param_scheduler
        self.use_checkpoint_opt_param_scheduler = use_checkpoint_opt_param_scheduler
        if self.override_opt_param_scheduler:
            assert not self.use_checkpoint_opt_param_scheduler, 'both override and '\
                'use-checkpoint are set.'

        # Set the learning rate
        self.step(0)
        print_rank_0('> learning rate decay style: {}'.format(self.lr_decay_style))


    def get_wd(self):
        """ Weight decay incr functions"""
        if self.num_steps > self.wd_incr_steps:
            return self.end_wd

        if self.wd_incr_style == 'constant':
            assert self.start_wd == self.end_wd
            return self.end_wd

        incr_ratio = float(self.num_steps) / float(self.wd_incr_steps)
        assert incr_ratio >= 0.0
        assert incr_ratio <= 1.0
        delta_wd = self.end_wd - self.start_wd

        if self.wd_incr_style == 'linear':
            coeff = incr_ratio
        elif self.wd_incr_style == 'cosine':
            coeff = 0.5 * (math.cos(math.pi * (1 - incr_ratio)) + 1.0)
        else:
            raise Exception('{} weight decay increment style is not supported.'.format(
                self.wd_incr_style))

        return self.start_wd + coeff * delta_wd
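
    # Worked example (illustrative values, not from the original code): with
    # start_wd=0.0, end_wd=0.1, wd_incr_steps=1000 and wd_incr_style='cosine',
    # at num_steps=500 we get incr_ratio=0.5 and
    # coeff = 0.5 * (cos(pi * (1 - 0.5)) + 1.0) = 0.5, hence wd = 0.05. The
    # cosine ramp leaves start_wd flat (coeff=0 at ratio 0) and approaches
    # end_wd flat (coeff=1 at ratio 1).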


    def get_lr(self):
        """Learning rate decay functions from:
              https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""

        # Use linear warmup for the initial part.
        if self.lr_warmup_steps > 0 and self.num_steps <= self.lr_warmup_steps:
            return self.max_lr * float(self.num_steps) / \
                float(self.lr_warmup_steps)

        # If the learning rate is constant, just return the initial value.
        if self.lr_decay_style == 'constant':
            return self.max_lr

        # For any steps larger than `self.lr_decay_steps`, use `self.min_lr`.
        if self.num_steps > self.lr_decay_steps:
            return self.min_lr
        
        # If we are done with the warmup period, use the decay style.
        num_steps_ = self.num_steps - self.lr_warmup_steps
        decay_steps_ = self.lr_decay_steps - self.lr_warmup_steps
        decay_ratio = float(num_steps_) / float(decay_steps_)
        assert decay_ratio >= 0.0
        assert decay_ratio <= 1.0
        delta_lr = self.max_lr - self.min_lr

        if self.lr_decay_style == 'linear':
            coeff = (1.0 - decay_ratio)
        elif self.lr_decay_style == 'cosine':
            coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
        else:
            raise Exception('{} decay style is not supported.'.format(
                self.lr_decay_style))

        return self.min_lr + coeff * delta_lr
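
    # Worked example (illustrative values, not from the original code): with
    # max_lr=1e-3, min_lr=1e-4, lr_warmup_steps=100, lr_decay_steps=1100 and
    # lr_decay_style='cosine', at num_steps=600 we get decay_ratio=0.5,
    # coeff = 0.5 * (cos(pi * 0.5) + 1.0) = 0.5, and therefore
    # lr = 1e-4 + 0.5 * (1e-3 - 1e-4) = 5.5e-4.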


    def step(self, increment):
        """Set lr for all parameters groups."""
        self.num_steps += increment
        new_lr = self.get_lr()
        new_wd = self.get_wd()
        for group in self.optimizer.param_groups:
            group['lr'] = new_lr * group.get('lr_mult', 1.0)
            group['weight_decay'] = new_wd * group.get('wd_mult', 1.0)
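
    # Note: `step` honors optional per-group multipliers. A param group
    # created with, e.g., {'params': ..., 'lr_mult': 0.1} (hypothetical
    # values) follows the same schedule at one tenth of the base learning
    # rate; groups without 'lr_mult'/'wd_mult' default to 1.0.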


    def state_dict(self):
        state_dict = {
            'max_lr': self.max_lr,
            'lr_warmup_steps': self.lr_warmup_steps,
            'num_steps': self.num_steps,
            'lr_decay_style': self.lr_decay_style,
            'lr_decay_steps': self.lr_decay_steps,
            'min_lr': self.min_lr,
            'start_wd': self.start_wd,
            'end_wd': self.end_wd,
            'wd_incr_style': self.wd_incr_style,
            'wd_incr_steps': self.wd_incr_steps
        }
        return state_dict


    def _check_and_set(self, cls_value, sd_value, name):
        """Auxiliary function for checking the values in the checkpoint and
        setting them."""
        if self.override_opt_param_scheduler:
            print_rank_0(' > overriding {} value to {}'.format(name, cls_value))
            return cls_value

        if not self.use_checkpoint_opt_param_scheduler:
            assert cls_value == sd_value, \
                f'OptimizerParamScheduler: class input value {cls_value} and checkpoint ' \
                f'value {sd_value} for {name} do not match'
        print_rank_0(' > using checkpoint value {} for {}'.format(sd_value,
                                                                  name))
        return sd_value


    def load_state_dict(self, sd):
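        """Load scheduler values from a checkpoint state dict; the branches
        below also accept legacy key names written by older checkpoints."""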

        if 'start_lr' in sd:
            max_lr_ = sd['start_lr']
        else:
            max_lr_ = sd['max_lr']
        self.max_lr = self._check_and_set(self.max_lr, max_lr_,
                                          'learning rate')
        
        self.min_lr = self._check_and_set(self.min_lr, sd['min_lr'],
                                          'minimum learning rate')

        if 'warmup_iter' in sd:
            lr_warmup_steps_ = sd['warmup_iter']
        elif 'warmup_steps' in sd:
            lr_warmup_steps_ = sd['warmup_steps']
        else:
            lr_warmup_steps_ = sd['lr_warmup_steps']
        self.lr_warmup_steps = self._check_and_set(self.lr_warmup_steps,
                                                   lr_warmup_steps_,
                                                   'warmup iterations')

        if 'end_iter' in sd:
            lr_decay_steps_ = sd['end_iter']
        elif 'decay_steps' in sd:
            lr_decay_steps_ = sd['decay_steps']
        else:
            lr_decay_steps_ = sd['lr_decay_steps']
        self.lr_decay_steps = self._check_and_set(self.lr_decay_steps,
                                                  lr_decay_steps_,
                                                  'total number of iterations')

        if 'decay_style' in sd:
            lr_decay_style_ = sd['decay_style']
        else:
            lr_decay_style_ = sd['lr_decay_style']
        self.lr_decay_style = self._check_and_set(self.lr_decay_style,
                                                  lr_decay_style_,
                                                  'learning rate decay style')

        if 'num_iters' in sd:
            num_steps = sd['num_iters']
        else:
            num_steps = sd['num_steps']
        self.step(increment=num_steps)

        if 'start_wd' in sd:
            self.start_wd = self._check_and_set(self.start_wd,
                                                sd['start_wd'],
                                                "start weight decay")
            self.end_wd = self._check_and_set(self.end_wd,
                                                sd['end_wd'],
                                                "end weight decay")
            self.wd_incr_steps = self._check_and_set(self.wd_incr_steps,
                                                     sd['wd_incr_steps'],
                                                     "total number of weight decay iterations")
            self.wd_incr_style = self._check_and_set(self.wd_incr_style,
                                                     sd['wd_incr_style'],
                                                     "weight decay incr style")