# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch optimization for BERT model."""

import logging
import math

import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR

logger = logging.getLogger(__name__)

class ConstantLRSchedule(LambdaLR):
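    """ Constant learning rate schedule: keeps the learning rate multiplier at 1.0 for every step. """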
    def __init__(self, optimizer, last_epoch=-1):
        super(ConstantLRSchedule, self).__init__(optimizer, lambda _: 1.0, last_epoch=last_epoch)

class WarmupCosineSchedule(LambdaLR):
    """
    Linearly increases the learning rate multiplier from 0 to 1 over `warmup_steps` training steps.
    Decreases the multiplier from 1. to 0. over the remaining `t_total - warmup_steps` steps following a cosine curve.
    If `cycles` (default=0.5) is different from the default, the multiplier instead follows `cycles` full cosine waves after warmup.

    :param optimizer:     wrapped optimizer whose learning rate is scheduled
    :param warmup_steps:  number of steps for the linear warmup
    :param t_total:       total number of training steps
    :param cycles:        number of cosine cycles. Default: 0.5, corresponding to cosine decay from 1. at the end of warmup to 0. at `t_total`.
    :param last_epoch:    index of the last epoch when resuming training. Default: -1
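
    Example (a minimal usage sketch; `model` and `num_training_steps` are assumed to be defined by the caller)::

        optimizer = AdamW(model.parameters(), lr=5e-5)
        scheduler = WarmupCosineSchedule(optimizer, warmup_steps=100, t_total=num_training_steps)
        for step in range(num_training_steps):
            ...                     # forward / backward pass goes here
            optimizer.step()
            scheduler.step()        # update the learning rate multiplier after each optimizer step
            optimizer.zero_grad()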
    """
    def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1):

        def lr_lambda(step):
            if step < warmup_steps:
                return float(step) / float(max(1.0, warmup_steps))
            else:
                progress = float(step - warmup_steps) / float(max(1, t_total - warmup_steps))   # progress after warmup
                return 0.5 * (1. + math.cos(math.pi * float(cycles) * 2.0 * progress))

        super(WarmupCosineSchedule, self).__init__(optimizer, lr_lambda, last_epoch=last_epoch)

class WarmupCosineWithHardRestartsSchedule(LambdaLR):
    """
    Linearly increases the learning rate multiplier from 0 to 1 over `warmup_steps` training steps.
    Over the remaining `t_total - warmup_steps` steps, the multiplier follows `cycles` (default=1.) cosine decays
    from 1. to 0., restarting at 1. at the start of each cycle (hard restarts).
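
    For example (illustrative values only), with `warmup_steps=0`, `t_total=10` and `cycles=2.` the multiplier decays
    from 1. towards 0. over steps 0-4 and jumps back to 1. at step 5 before decaying again.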
    """
    def __init__(self, optimizer, warmup_steps, t_total, cycles=1., last_epoch=-1):

        def lr_lambda(step):
            if step < warmup_steps:
                return float(step) / float(max(1, warmup_steps))
            else:
                progress = float(step - warmup_steps) / float(max(1, t_total - warmup_steps))   # progress after warmup
                if progress >= 1.0:
                    return 0.0
                return 0.5 * (1. + math.cos(math.pi * ((float(cycles) * progress) % 1.0)))

        super(WarmupCosineWithHardRestartsSchedule, self).__init__(optimizer, lr_lambda, last_epoch=last_epoch)


class WarmupConstantSchedule(LambdaLR):
    """
    Linearly increases the learning rate multiplier from 0 to 1 over `warmup_steps` training steps.
    Keeps the multiplier equal to 1. after warmup.
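
    For example (illustrative values only), with `warmup_steps=100` the multiplier is `step / 100` for steps 0-99
    and 1. from step 100 onwards.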
    """
    def __init__(self, optimizer, warmup_steps, last_epoch=-1):

        def lr_lambda(step):
            if step < warmup_steps:
                return float(step) / float(max(1.0, warmup_steps))
            return 1.

        super(WarmupConstantSchedule, self).__init__(optimizer, lr_lambda, last_epoch=last_epoch)


class WarmupLinearSchedule(LambdaLR):
    """
    Linearly increases the learning rate multiplier from 0 to 1 over `warmup_steps` training steps.
    Linearly decreases the multiplier from 1. to 0. over the remaining `t_total - warmup_steps` steps.
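
    Example (a minimal construction sketch; `model` and `num_training_steps` are assumed to be defined by the caller)::

        optimizer = AdamW(model.parameters(), lr=5e-5)
        scheduler = WarmupLinearSchedule(optimizer, warmup_steps=500, t_total=num_training_steps)
        # call scheduler.step() once after every optimizer.step() during training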
    """
    def __init__(self, optimizer, warmup_steps, t_total, last_epoch=-1):

        def lr_lambda(step):
            if step < warmup_steps:
                return float(step) / float(max(1, warmup_steps))
            return max(0.0, float(t_total - step) / float(max(1.0, t_total - warmup_steps)))  # never go below 0 past t_total

        super(WarmupLinearSchedule, self).__init__(optimizer, lr_lambda, last_epoch=last_epoch)


class AdamW(Optimizer):
    """ Implements Adam algorithm with weight decay fix.

    Parameters:
        lr: learning rate. Default: 1e-3
        betas: Adam's (beta1, beta2) coefficients. Default: (0.9, 0.999)
        eps: Adam's epsilon for numerical stability. Default: 1e-6
        weight_decay: decoupled weight decay to apply. Default: 0.01
        correct_bias: can be set to False to avoid correcting bias in Adam (e.g. like in the original BERT repository). Default: True
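
    Example (a minimal usage sketch; `model`, `loss`, `scheduler` and the surrounding training loop are assumed to exist elsewhere)::

        optimizer = AdamW(model.parameters(), lr=5e-5, weight_decay=0.01, correct_bias=False)
        loss.backward()
        # gradient clipping is no longer handled inside the optimizer; do it explicitly if needed
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()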
    """
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01, correct_bias=True):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1]))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
                        correct_bias=correct_bias)
        super(AdamW, self).__init__(params, defaults)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                # Decay the first and second moment running average coefficient
                # In-place operations to update the averages at the same time
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                denom = exp_avg_sq.sqrt().add_(group['eps'])

                step_size = group['lr']
                if group['correct_bias']:  # bias correction is skipped when correct_bias=False, as in the original BERT implementation
                    bias_correction1 = 1.0 - beta1 ** state['step']
                    bias_correction2 = 1.0 - beta2 ** state['step']
                    step_size = step_size * math.sqrt(bias_correction2) / bias_correction1

                p.data.addcdiv_(exp_avg, denom, value=-step_size)

                # Just adding the square of the weights to the loss function is *not*
                # the correct way of using L2 regularization/weight decay with Adam,
                # since that will interact with the m and v parameters in strange ways.
                #
                # Instead we want to decay the weights in a manner that doesn't interact
                # with the m/v parameters. This is equivalent to adding the square
                # of the weights to the loss with plain (non-momentum) SGD.
                # Add weight decay at the end (fixed version)
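                # Effectively: p <- p - lr * weight_decay * p, decoupled from the adaptive Adam update.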
                if group['weight_decay'] > 0.0:
                    p.data.add_(p.data, alpha=-group['lr'] * group['weight_decay'])

        return loss