optimization.py 7.7 KB
Newer Older
thomwolf's avatar
thomwolf committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch optimization for BERT model."""

17
18
19
import math
import torch
from torch.optim import Optimizer
Li Li's avatar
Li Li committed
20
from torch.optim.optimizer import required
21
from torch.nn.utils import clip_grad_norm_
lukovnikov's avatar
lukovnikov committed
22
23
24
import logging

# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
25
26

def warmup_cosine(x, warmup=0.002):
    """ Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps,
        then follows the cosine curve 0.5 * (1 + cos(pi * x)) of the overall training progress `x`
        (which reaches 0 at x == 1 and rises again for x > 1).

        `x` is the training progress step / t_total (a plain Python number); the return value is the
        multiplier applied to the base learning rate. """
    if x < warmup:
        return x/warmup
    # `x` is a Python float, not a tensor: torch.cos would raise TypeError here, so use math.cos.
    return 0.5 * (1.0 + math.cos(math.pi * x))
30
31

def warmup_constant(x, warmup=0.002):
    """ Warmup-then-constant schedule: the multiplier ramps linearly from 0 to 1 over the first
        `warmup`*`t_total` (as provided to BertAdam) training steps and stays at 1. afterwards.

        `x` is the training progress step / t_total; the result scales the base learning rate. """
    in_warmup = x < warmup
    return x/warmup if in_warmup else 1.0
37
38

def warmup_linear(x, warmup=0.002):
    """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step.
        After `t_total`-th training step, learning rate is zero.

        `x` is the training progress step / t_total; the result scales the base learning rate.
        A negative `warmup` (BertAdam uses -1 for "no warmup") is treated as 0, so the schedule then
        decays linearly from 1 at x == 0 to 0 at x == 1. """
    # Without this clamp, warmup == -1 would yield (1 - x) / 2, i.e. half the intended learning rate.
    warmup = max(warmup, 0.0)
    if x < warmup:
        return x/warmup
    return max((x-1.)/(warmup-1.), 0)
44
45

# Maps each `schedule` name accepted by BertAdam to the function that computes its
# learning-rate multiplier from (progress, warmup).
SCHEDULES = dict(
    warmup_cosine=warmup_cosine,
    warmup_constant=warmup_constant,
    warmup_linear=warmup_linear,
)


thomwolf's avatar
thomwolf committed
52
class BertAdam(Optimizer):
    """Implements BERT version of Adam algorithm with weight decay fix.

    Differences from torch.optim.Adam visible in this implementation:
      * no bias correction of the moment estimates;
      * weight decay is added to the normalized Adam update (not to the gradient), so it does
        not interact with the m/v moving averages;
      * gradients are clipped per parameter tensor (not by a global norm) when max_grad_norm > 0;
      * the learning rate can follow one of the schedules in SCHEDULES, evaluated at the
        training progress step / t_total (the first step uses progress 0).

    Params:
        lr: learning rate
        warmup: portion of t_total for the warmup, -1  means no warmup. Default: -1
        t_total: total number of training steps for the learning
            rate schedule, -1  means constant learning rate. Default: -1
        schedule: name of the schedule to use (a key of SCHEDULES: 'warmup_cosine',
            'warmup_constant' or 'warmup_linear'); ignored when t_total is -1. Default: 'warmup_linear'
        b1: Adams b1. Default: 0.9
        b2: Adams b2. Default: 0.999
        e: Adams epsilon. Default: 1e-6
        weight_decay: Weight decay. Default: 0.01
        max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
    """
    def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear',
                 b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01,
                 max_grad_norm=1.0):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        if schedule not in SCHEDULES:
            raise ValueError("Invalid schedule parameter: {}".format(schedule))
        if not 0.0 <= warmup < 1.0 and not warmup == -1:
            raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
        if not 0.0 <= b1 < 1.0:
            raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
        if not 0.0 <= b2 < 1.0:
            raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
        if not e >= 0.0:
            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
        defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
                        b1=b1, b2=b2, e=e, weight_decay=weight_decay,
                        max_grad_norm=max_grad_norm)
        super(BertAdam, self).__init__(params, defaults)

    @staticmethod
    def _scheduled_lr(group, step):
        """Return the learning rate of param group `group` after `step` optimizer steps.

        When group['t_total'] is -1 the base rate group['lr'] is returned unchanged; otherwise it
        is scaled by the group's schedule evaluated at the progress step / t_total.
        """
        if group['t_total'] == -1:
            return group['lr']
        schedule_fct = SCHEDULES[group['schedule']]
        # warmup == -1 means "no warmup": hand the schedule 0 instead, since warmup_linear
        # would otherwise return half the intended learning rate for a negative warmup.
        warmup = max(group['warmup'], 0.0)
        return group['lr'] * schedule_fct(step / group['t_total'], warmup)

    def get_lr(self):
        """Return the current scheduled learning rate of every parameter, in param-group order.

        Returns [0] as soon as a parameter without optimizer state is found (i.e. one that
        step() has not updated yet).
        """
        lr = []
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                if len(state) == 0:
                    return [0]
                lr.append(self._scheduled_lr(group, state['step']))
        return lr

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns the closure's loss, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        # Log the "training beyond t_total" warning at most once per call to step().
        warned_for_t_total = False

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['next_m'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['next_v'] = torch.zeros_like(p.data)

                next_m, next_v = state['next_m'], state['next_v']
                beta1, beta2 = group['b1'], group['b2']

                # Clip this parameter's gradient norm in place (per tensor, not a global norm).
                # `grad` shares storage with p.grad, so it sees the clipped values.
                if group['max_grad_norm'] > 0:
                    clip_grad_norm_(p, group['max_grad_norm'])

                # Decay the first and second moment running average coefficient
                # In-place operations to update the averages at the same time
                next_m.mul_(beta1).add_(grad, alpha=1 - beta1)
                next_v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                update = next_m / (next_v.sqrt() + group['e'])

                # Just adding the square of the weights to the loss function is *not*
                # the correct way of using L2 regularization/weight decay with Adam,
                # since that will interact with the m and v parameters in strange ways.
                #
                # Instead we want to decay the weights in a manner that doesn't interact
                # with the m/v parameters. This is equivalent to adding the square
                # of the weights to the loss with plain (non-momentum) SGD.
                if group['weight_decay'] > 0.0:
                    update += group['weight_decay'] * p.data

                lr_scheduled = self._scheduled_lr(group, state['step'])

                # Warning for exceeding t_total (only active with warmup_linear, whose rate is 0 past t_total)
                if (group['t_total'] != -1 and group['schedule'] == "warmup_linear"
                        and state['step'] / group['t_total'] > 1. and not warned_for_t_total):
                    logger.warning(
                        "Training beyond specified 't_total' steps with schedule '{}'. Learning rate set to {}. "
                        "Please set 't_total' of {} correctly.".format(group['schedule'], lr_scheduled, self.__class__.__name__))
                    warned_for_t_total = True

                update_with_lr = lr_scheduled * update
                p.data.add_(-update_with_lr)

                state['step'] += 1

                # No bias correction: unlike torch.optim.Adam, next_m / next_v are not divided by
                # (1 - beta ** step) before computing the update.

        return loss