# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch optimization utilities for BERT models: learning-rate schedules and an AdamW optimizer."""

import logging
import math

import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class ConstantLRSchedule(LambdaLR):
    """ Constant learning rate schedule.

        The multiplier is 1.0 at every step, so each parameter group keeps
        the learning rate it started with.
    """
    def __init__(self, optimizer, last_epoch=-1):
        def _unit_multiplier(_step):
            # Same value regardless of the step index.
            return 1.0

        super(ConstantLRSchedule, self).__init__(optimizer, _unit_multiplier, last_epoch=last_epoch)


class WarmupConstantSchedule(LambdaLR):
    """ Linear warmup followed by a constant learning-rate multiplier.

        During the first `warmup_steps` steps the multiplier grows linearly from 0 to 1;
        from step `warmup_steps` onward it is held at 1.
    """
    def __init__(self, optimizer, warmup_steps, last_epoch=-1):
        # Stored before the base constructor runs, because LambdaLR's constructor
        # already evaluates lr_lambda for the initial step.
        self.warmup_steps = warmup_steps
        super(WarmupConstantSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

    def lr_lambda(self, step):
        """Return the learning-rate multiplier for scheduler step `step`."""
        if step >= self.warmup_steps:
            return 1.
        warmup_length = float(max(1.0, self.warmup_steps))
        return float(step) / warmup_length


class WarmupLinearSchedule(LambdaLR):
    """ Linear warmup, then linear decay to zero.

        The multiplier rises linearly from 0 to 1 over the first `warmup_steps` steps,
        then falls linearly from 1 to 0 over the following `t_total - warmup_steps` steps,
        and is clipped at 0 for any step past `t_total`.
    """
    def __init__(self, optimizer, warmup_steps, t_total, last_epoch=-1):
        # Attributes must exist before LambdaLR's constructor evaluates lr_lambda.
        self.warmup_steps = warmup_steps
        self.t_total = t_total
        super(WarmupLinearSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

    def lr_lambda(self, step):
        """Return the learning-rate multiplier for scheduler step `step`."""
        warmup = self.warmup_steps
        if step >= warmup:
            steps_left = float(self.t_total - step)
            decay_length = float(max(1.0, self.t_total - warmup))
            return max(0.0, steps_left / decay_length)
        return float(step) / float(max(1, warmup))


class WarmupCosineSchedule(LambdaLR):
    """ Linear warmup and then cosine decay.
        Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps.
        Decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps following a cosine curve.
        `cycles` (default=0.5) is the number of cosine periods traversed between the end of warmup and
        `t_total`; the default half period gives a single smooth decay from 1 to 0.
        Once `t_total` is reached the multiplier stays at 0 (as in WarmupLinearSchedule and
        WarmupCosineWithHardRestartsSchedule) instead of following the cosine back up.
    """
    def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1):
        # Attributes must exist before LambdaLR's constructor evaluates lr_lambda.
        self.warmup_steps = warmup_steps
        self.t_total = t_total
        self.cycles = cycles
        super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

    def lr_lambda(self, step):
        """Return the learning-rate multiplier for scheduler step `step`."""
        if step < self.warmup_steps:
            return float(step) / float(max(1.0, self.warmup_steps))
        # progress after warmup: 0 at the end of warmup, 1 at t_total
        progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps))
        if progress >= 1.0:
            # Schedule finished. Without this guard the cosine keeps oscillating past
            # t_total (with cycles=0.5 it climbs back to 1 at 2 * t_total).
            return 0.0
        return max(0.0, 0.5 * (1. + math.cos(math.pi * float(self.cycles) * 2.0 * progress)))


class WarmupCosineWithHardRestartsSchedule(LambdaLR):
    """ Linear warmup, then `cycles` cosine decays with hard restarts.

        The multiplier rises linearly from 0 to 1 over the first `warmup_steps` steps.
        The window from `warmup_steps` to `t_total` is then split into `cycles` (default=1.)
        equal parts; in each part the multiplier follows a half cosine from 1 down towards 0
        and jumps back to 1 at the start of the next part. From `t_total` on it is 0.
    """
    def __init__(self, optimizer, warmup_steps, t_total, cycles=1., last_epoch=-1):
        # Attributes must exist before LambdaLR's constructor evaluates lr_lambda.
        self.warmup_steps = warmup_steps
        self.t_total = t_total
        self.cycles = cycles
        super(WarmupCosineWithHardRestartsSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

    def lr_lambda(self, step):
        """Return the learning-rate multiplier for scheduler step `step`."""
        warmup = self.warmup_steps
        if step < warmup:
            return float(step) / float(max(1, warmup))
        # Fraction of the post-warmup window already elapsed (1.0 at t_total).
        elapsed = float(step - warmup) / float(max(1, self.t_total - warmup))
        if elapsed < 1.0:
            # Position inside the current cycle, in [0, 1); each cycle restarts at 1.
            phase = (float(self.cycles) * elapsed) % 1.0
            return max(0.0, 0.5 * (1. + math.cos(math.pi * phase)))
        return 0.0


class AdamW(Optimizer):
    """ Implements Adam algorithm with weight decay fix.

    Weight decay is decoupled from the gradient-based update: it is applied
    directly to the parameters after the Adam step rather than being added
    to the loss (see the comment at the end of `step`).

    Parameters:
        params (iterable): parameters to optimize, or dicts defining parameter groups.
        lr (float): learning rate. Default 1e-3.
        betas (tuple of 2 floats): Adam's beta parameters (b1, b2). Default: (0.9, 0.999)
        eps (float): Adam's epsilon. Default: 1e-6
        weight_decay (float): decoupled weight decay coefficient, must be >= 0. Default: 0.0
        correct_bias (bool): can be set to False to avoid correcting bias in Adam (e.g. like in Bert TF repository). Default True.

    Raises:
        ValueError: if lr, eps or weight_decay is negative, or a beta is outside [0.0, 1.0[.
    """
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0, correct_bias=True):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1]))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
        # Previously a negative value was silently ignored by the `> 0.0` guard in step().
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {} - should be >= 0.0".format(weight_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
                        correct_bias=correct_bias)
        super(AdamW, self).__init__(params, defaults)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by `closure`, or None when no closure is given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')

                state = self.state[p]

                # State initialization (first step in which p has a gradient)
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                # Decay the first and second moment running average coefficient.
                # In-place operations update the averages stored in `state`.
                # The keyword `alpha=` / `value=` forms replace the deprecated
                # positional-scalar overloads (e.g. add_(Number, Tensor)).
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                denom = exp_avg_sq.sqrt().add_(group['eps'])

                step_size = group['lr']
                if group['correct_bias']:  # skipped when correct_bias=False (as in the Bert TF repository)
                    bias_correction1 = 1.0 - beta1 ** state['step']
                    bias_correction2 = 1.0 - beta2 ** state['step']
                    step_size = step_size * math.sqrt(bias_correction2) / bias_correction1

                p.data.addcdiv_(exp_avg, denom, value=-step_size)

                # Just adding the square of the weights to the loss function is *not*
                # the correct way of using L2 regularization/weight decay with Adam,
                # since that will interact with the m and v parameters in strange ways.
                #
                # Instead we want to decay the weights in a manner that doesn't interact
                # with the m/v parameters. This is equivalent to adding the square
                # of the weights to the loss with plain (non-momentum) SGD.
                # Add weight decay at the end (fixed version)
                if group['weight_decay'] > 0.0:
                    p.data.add_(p.data, alpha=-group['lr'] * group['weight_decay'])

        return loss