optimization.py 6.67 KB
Newer Older
thomwolf's avatar
thomwolf committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch optimization for BERT model."""

17
18
19
import math
import torch
from torch.optim import Optimizer
Li Li's avatar
Li Li committed
20
from torch.optim.optimizer import required
21
22
23
from torch.nn.utils import clip_grad_norm_

def warmup_cosine(x, warmup=0.002):
    """Learning-rate multiplier: linear warmup, then cosine decay.

    Args:
        x: training progress; BertAdam passes ``step / t_total``.
        warmup: fraction of progress spent in the linear warmup phase.

    Returns:
        ``x / warmup`` while ``x < warmup``; afterwards
        ``0.5 * (1 + cos(pi * x))``, which falls to 0 at ``x == 1``.
    """
    if x < warmup:
        return x / warmup
    # ``x`` is a plain Python number, so use math.cos: torch.cos only accepts
    # tensors and raises TypeError when given a float.
    return 0.5 * (1.0 + math.cos(math.pi * x))
27
28

def warmup_constant(x, warmup=0.002):
    """Learning-rate multiplier: linear warmup, then held constant at 1.0.

    Args:
        x: training progress (BertAdam passes ``step / t_total``).
        warmup: fraction of progress spent ramping up linearly.

    Returns:
        ``x / warmup`` during warmup, ``1.0`` afterwards.
    """
    still_warming_up = x < warmup
    return x / warmup if still_warming_up else 1.0
32
33

def warmup_linear(x, warmup=0.002):
    """Learning-rate multiplier: linear warmup, then linear decay to zero.

    Args:
        x: training progress (BertAdam passes ``step / t_total``).
        warmup: fraction of progress spent ramping up linearly.

    Returns:
        ``x / warmup`` while ``x < warmup``; afterwards ``1 - x``, clamped so it
        never drops below 0.
    """
    if x < warmup:
        return x / warmup
    # Clamp at zero: if training runs past t_total (x > 1), ``1 - x`` would go
    # negative and the optimizer would step *up* the loss gradient.
    return max(1.0 - x, 0.0)
37
38
39
40
41
42
43
44

# Registry of learning-rate schedules, keyed by the ``schedule`` name that
# BertAdam validates against and looks up at step time.
SCHEDULES = dict(
    warmup_cosine=warmup_cosine,
    warmup_constant=warmup_constant,
    warmup_linear=warmup_linear,
)


thomwolf's avatar
thomwolf committed
45
class BertAdam(Optimizer):
    """Implements BERT version of Adam algorithm with weight decay fix.

    Differences from ``torch.optim.Adam`` in this implementation:
      - no bias correction of the moment estimates;
      - weight decay is added directly to the update (decoupled from m/v);
      - optional gradient-norm clipping, applied to each parameter tensor
        separately (not a global norm over all parameters);
      - optional learning-rate schedule looked up in ``SCHEDULES``.

    Params:
        lr: learning rate
        warmup: portion of t_total for the warmup, -1  means no warmup. Default: -1
        t_total: total number of training steps for the learning
            rate schedule, -1  means constant learning rate. Default: -1
        schedule: name of the schedule in ``SCHEDULES`` to use. Default: 'warmup_linear'
        b1: Adams b1. Default: 0.9
        b2: Adams b2. Default: 0.999
        e: Adams epsilon. Default: 1e-6
        weight_decay_rate: Weight decay. Default: 0.01
        max_grad_norm: Maximum norm for the gradients (values <= 0, e.g. -1,
            disable clipping). Default: 1.0
    """
    def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear',
                 b1=0.9, b2=0.999, e=1e-6, weight_decay_rate=0.01,
                 max_grad_norm=1.0):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        if schedule not in SCHEDULES:
            raise ValueError("Invalid schedule parameter: {}".format(schedule))
        if not 0.0 <= warmup < 1.0 and not warmup == -1:
            raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
        if not 0.0 <= b1 < 1.0:
            raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
        if not 0.0 <= b2 < 1.0:
            raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
        if not e >= 0.0:
            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
        defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
                        b1=b1, b2=b2, e=e, weight_decay_rate=weight_decay_rate,
                        max_grad_norm=max_grad_norm)
        super(BertAdam, self).__init__(params, defaults)

    @staticmethod
    def _scheduled_lr(group, step):
        """Return the learning rate of ``group`` at optimizer step ``step``.

        With ``t_total == -1`` the base ``lr`` is returned unchanged; otherwise
        it is scaled by the group's schedule evaluated at ``step / t_total``.
        """
        if group['t_total'] == -1:
            return group['lr']
        schedule_fct = SCHEDULES[group['schedule']]
        # float() keeps this a true division even under Python 2 semantics.
        progress = step / float(group['t_total'])
        return group['lr'] * schedule_fct(progress, group['warmup'])

    def get_lr(self):
        """Return the current scheduled learning rate for every parameter.

        Returns ``[0]`` as soon as any parameter has no optimizer state yet
        (i.e. ``step`` has not processed it).
        """
        lr = []
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                if len(state) == 0:
                    return [0]
                lr.append(self._scheduled_lr(group, state['step']))
        return lr

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The closure's loss, or None when no closure is given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['next_m'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['next_v'] = torch.zeros_like(p.data)

                next_m, next_v = state['next_m'], state['next_v']
                beta1, beta2 = group['b1'], group['b2']

                # Gradient clipping for this parameter only. clip_grad_norm_
                # rescales p.grad in place, and `grad` shares its storage.
                if group['max_grad_norm'] > 0:
                    clip_grad_norm_(p, group['max_grad_norm'])

                # Decay the first and second moment running average coefficient.
                # In-place operations update the averages at the same time;
                # keyword alpha/value forms replace the deprecated
                # add_(Number, Tensor) / addcmul_(Number, Tensor, Tensor) overloads.
                next_m.mul_(beta1).add_(grad, alpha=1 - beta1)
                next_v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                update = next_m / (next_v.sqrt() + group['e'])

                # Just adding the square of the weights to the loss function is *not*
                # the correct way of using L2 regularization/weight decay with Adam,
                # since that will interact with the m and v parameters in strange ways.
                #
                # Instead we want to decay the weights in a manner that doesn't interact
                # with the m/v parameters. This is equivalent to adding the square
                # of the weights to the loss with plain (non-momentum) SGD.
                if group['weight_decay_rate'] > 0.0:
                    update += group['weight_decay_rate'] * p.data

                lr_scheduled = self._scheduled_lr(group, state['step'])
                update_with_lr = lr_scheduled * update
                p.data.add_(-update_with_lr)

                state['step'] += 1

                # No bias correction: unlike torch.optim.Adam, m and v are used
                # as-is (1 - beta1 ** step and 1 - beta2 ** step are not applied).

        return loss