"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "df146c41e27e6e15c0ab51af9070c87c82aa4fc7"
Commit b95f1b5d authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Add LAMB optimizer

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/572

Differential Revision: D15317928

Pulled By: myleott

fbshipit-source-id: b3f0e9229737a63b49937e7c5b918470f18ddc45
parent d0577ba7
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
"""
LAMB optimizer from github.com/cybertronai/pytorch-lamb.
"""
import math
import torch
import torch.optim
from . import FairseqOptimizer, register_optimizer
@register_optimizer('lamb')
class FairseqLamb(FairseqOptimizer):
def __init__(self, args, params):
super().__init__(args, params)
self._optimizer = Lamb(params, **self.optimizer_config)
@staticmethod
def add_args(parser):
"""Add optimizer-specific arguments to the parser."""
# fmt: off
parser.add_argument('--lamb-betas', default='(0.9, 0.999)', metavar='B',
help='betas for Adam optimizer')
parser.add_argument('--lamb-eps', type=float, default=1e-8, metavar='D',
help='epsilon for Adam optimizer')
parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
help='weight decay')
# fmt: on
@property
def optimizer_config(self):
"""
Return a kwarg dictionary that will be used to override optimizer
args stored in checkpoints. This allows us to load a checkpoint and
resume training using a different set of optimizer args, e.g., with a
different learning rate.
"""
return {
'lr': self.args.lr[0],
'betas': eval(self.args.lamb_betas),
'eps': self.args.lamb_eps,
'weight_decay': self.args.weight_decay,
}
class Lamb(torch.optim.Optimizer):
r"""Implements Lamb algorithm.
It has been proposed in `Reducing BERT Pre-Training Time from 3 Days to 76 Minutes`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
adam (bool, optional): always use trust ratio = 1, which turns this into
Adam. Useful for comparison purposes.
.. _Reducing BERT Pre-Training Time from 3 Days to 76 Minutes:
https://arxiv.org/abs/1904.00962
"""
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
weight_decay=0, adam=False):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
defaults = dict(lr=lr, betas=betas, eps=eps,
weight_decay=weight_decay)
self.adam = adam
super(Lamb, self).__init__(params, defaults)
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instad.')
state = self.state[p]
# State initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p.data)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas']
state['step'] += 1
if group['weight_decay'] != 0:
grad.add_(group['weight_decay'], p.data)
# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(1 - beta1, grad)
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
denom = exp_avg_sq.sqrt().add_(group['eps'])
bias_correction1 = 1 - beta1 ** state['step']
bias_correction2 = 1 - beta2 ** state['step']
# Apply bias to lr to avoid broadcast.
step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
adam_step = exp_avg / denom
# L2 norm uses sum, but here since we're dividing, use mean to avoid overflow.
r1 = p.data.pow(2).mean().sqrt()
r2 = adam_step.pow(2).mean().sqrt()
r = 1 if r1 == 0 or r2 == 0 else min(r1/r2, 10)
state['r1'] = r1
state['r2'] = r2
state['r'] = r
if self.adam:
r = 1
p.data.add_(-step_size * r, adam_step)
return loss
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment