"vscode:/vscode.git/clone" did not exist on "0dd5cc753bdc23f3fc1c9ca82170efde4d0c68aa"
optimization_openai.py 6.56 KB
Newer Older
thomwolf's avatar
thomwolf committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch optimization for OpenAI GPT model."""

import math
import torch
from torch.optim import Optimizer
from torch.optim.optimizer import required
from torch.nn.utils import clip_grad_norm_
import logging

logger = logging.getLogger(__name__)

def warmup_cosine(x, warmup=0.002):
    """ Linearly increases learning rate over `warmup`*`t_total` (as provided to OpenAIAdam) training steps,
        then follows a cosine decay towards 0 at `t_total` (based on the overall training progress `x`). """
    if x < warmup:
        return x/warmup
    return 0.5 * (1.0 + math.cos(math.pi * x))

def warmup_constant(x, warmup=0.002):
    """ Linearly increases learning rate over `warmup`*`t_total` (as provided to OpenAIAdam) training steps.
        Learning rate is 1. afterwards. """
    if x < warmup:
        return x/warmup
    return 1.0

def warmup_linear(x, warmup=0.002):
    """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to OpenAIAdam) training step.
40
41
42
        After `t_total`-th training step, learning rate is zero. """
    if x < warmup:
        return x/warmup
    return max((x-1.)/(warmup-1.), 0)

SCHEDULES = {
    'warmup_cosine': warmup_cosine,
    'warmup_constant': warmup_constant,
    'warmup_linear': warmup_linear,
}
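
# A rough illustration of how these schedule multipliers behave (values below assume
# the default warmup=0.002 and are derived from the functions above; they are not part
# of the original source):
#   warmup_linear(0.001)  -> 0.5    # halfway through the warmup phase
#   warmup_linear(0.002)  -> 1.0    # peak at the end of warmup
#   warmup_linear(0.501)  -> ~0.5   # linear decay back towards 0.0 at x == 1.0
#   warmup_constant(0.5)  -> 1.0    # constant 1.0 once warmup is over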


class OpenAIAdam(Optimizer):
    """Implements Open AI version of Adam algorithm with weight decay fix.
    """
    def __init__(self, params, lr=required, schedule='warmup_linear', warmup=-1, t_total=-1,
                 b1=0.9, b2=0.999, e=1e-8, weight_decay=0,
                 vector_l2=False, max_grad_norm=-1, **kwargs):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        if schedule not in SCHEDULES:
            raise ValueError("Invalid schedule parameter: {}".format(schedule))
        if not 0.0 <= warmup < 1.0 and not warmup == -1:
            raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
        if not 0.0 <= b1 < 1.0:
            raise ValueError("Invalid b1 parameter: {}".format(b1))
        if not 0.0 <= b2 < 1.0:
            raise ValueError("Invalid b2 parameter: {}".format(b2))
        if not e >= 0.0:
            raise ValueError("Invalid epsilon value: {}".format(e))
        defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
                        b1=b1, b2=b2, e=e, weight_decay=weight_decay, vector_l2=vector_l2,
                        max_grad_norm=max_grad_norm)
        super(OpenAIAdam, self).__init__(params, defaults)
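
    # Illustrative usage (a sketch, not taken from this file; ``model`` and
    # ``num_train_steps`` are assumed to be defined by the calling training script):
    #
    #   optimizer = OpenAIAdam(model.parameters(),
    #                          lr=6.25e-5,
    #                          schedule='warmup_linear',
    #                          warmup=0.002,
    #                          t_total=num_train_steps,
    #                          weight_decay=0.01,
    #                          max_grad_norm=1.0)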

    def get_lr(self):
        lr = []
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                if len(state) == 0:
                    return [0]
                if group['t_total'] != -1:
                    schedule_fct = SCHEDULES[group['schedule']]
                    lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
                else:
                    lr_scheduled = group['lr']
                lr.append(lr_scheduled)
        return lr

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        warned_for_t_total = False

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['b1'], group['b2']

                state['step'] += 1

                # Add grad clipping
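                # (applied per parameter here, so each parameter's gradient norm is clipped
                # individually rather than the global norm across all parameters)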
                if group['max_grad_norm'] > 0:
                    clip_grad_norm_(p, group['max_grad_norm'])

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                denom = exp_avg_sq.sqrt().add_(group['e'])

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                if group['t_total'] != -1:
                    schedule_fct = SCHEDULES[group['schedule']]
                    progress = state['step']/group['t_total']
                    lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'])
                    # warning for exceeding t_total (only active with the warmup_linear schedule)
                    if group['schedule'] == "warmup_linear" and progress > 1. and not warned_for_t_total:
                        logger.warning(
                            "Training beyond specified 't_total' steps with schedule '{}'. Learning rate set to {}. "
                            "Please set 't_total' of {} correctly.".format(group['schedule'], lr_scheduled, self.__class__.__name__))
                        warned_for_t_total = True
                    # end warning
                else:
                    lr_scheduled = group['lr']

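                # Bias correction is folded into the step size instead of being applied to the
                # moment estimates directly; this matches the standard Adam update up to where
                # the epsilon term enters the denominator.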
                step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1

                p.data.addcdiv_(exp_avg, denom, value=-step_size)

                # Add weight decay at the end (fixed version)
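                # The decay is applied directly to the parameters with the scheduled learning
                # rate, decoupled from the adaptive Adam update above (the "weight decay fix"
                # of decoupled weight decay regularization), rather than being folded into the
                # gradient as L2 regularization.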
                if (len(p.size()) > 1 or group['vector_l2']) and group['weight_decay'] > 0:
                    p.data.add_(p.data, alpha=-lr_scheduled * group['weight_decay'])

        return loss