"common/git@developer.sourcefind.cn:OpenDAS/llama.cpp.git" did not exist on "cfb179e7c9a905b2a520ab623df1ffe9510d4555"
Commit 969f4474 authored by Myle Ott, committed by Facebook Github Bot

Remove LAMB optimizer (at least until we can test it more)

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/1008

Differential Revision: D16763315

Pulled By: myleott

fbshipit-source-id: d4bad8384eec273f2d5de4ed29fb8d158ab9187c
parent 3bbdc554
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
LAMB optimizer from github.com/cybertronai/pytorch-lamb.
"""
import math

import torch
import torch.optim

from . import FairseqOptimizer, register_optimizer
@register_optimizer('lamb')
class FairseqLamb(FairseqOptimizer):

    def __init__(self, args, params):
        super().__init__(args, params)
        self._optimizer = Lamb(params, **self.optimizer_config)

    @staticmethod
    def add_args(parser):
        """Add optimizer-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--lamb-betas', default='(0.9, 0.999)', metavar='B',
                            help='betas for LAMB optimizer')
        parser.add_argument('--lamb-eps', type=float, default=1e-8, metavar='D',
                            help='epsilon for LAMB optimizer')
        parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
                            help='weight decay')
        # fmt: on

    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.
        """
        return {
            'lr': self.args.lr[0],
            'betas': eval(self.args.lamb_betas),
            'eps': self.args.lamb_eps,
            'weight_decay': self.args.weight_decay,
        }
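# Illustrative note (not part of the original file): the kwarg dict returned by
# optimizer_config is what lets a resumed run override the optimizer settings
# saved in a checkpoint. A hedged sketch of how it is consumed:
#
#     opt = FairseqLamb(args, model.parameters())
#     # When loading a checkpoint, fairseq uses opt.optimizer_config to override
#     # the hyperparameters stored in the checkpointed optimizer state, so e.g.
#     # passing a different --lr on the command line takes effect on resume.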
class Lamb(torch.optim.Optimizer):
    r"""Implements Lamb algorithm.

    It has been proposed in `Reducing BERT Pre-Training Time from 3 Days to 76 Minutes`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        adam (bool, optional): always use trust ratio = 1, which turns this into
            Adam. Useful for comparison purposes.

    .. _Reducing BERT Pre-Training Time from 3 Days to 76 Minutes:
        https://arxiv.org/abs/1904.00962
    """
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, adam=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay)
        self.adam = adam
        super(Lamb, self).__init__(params, defaults)
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                # L2 regularization: fold weight decay into the gradient before
                # the moment updates.
                if group['weight_decay'] != 0:
                    grad.add_(group['weight_decay'], p.data)

                # Decay the first and second moment running average coefficients.
                exp_avg.mul_(beta1).add_(1 - beta1, grad)
                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)

                denom = exp_avg_sq.sqrt().add_(group['eps'])

                # Apply the bias corrections to the learning rate rather than to
                # the moments, to avoid broadcasting over the parameter tensor.
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

                adam_step = exp_avg / denom

                # Trust ratio: norm of the weights over norm of the Adam step.
                # RMS (mean) is used instead of the L2 sum to avoid overflow; the
                # constant factor cancels in the ratio since both tensors have the
                # same shape.
                r1 = p.data.pow(2).mean().sqrt()
                r2 = adam_step.pow(2).mean().sqrt()
                r = 1 if r1 == 0 or r2 == 0 else min(r1 / r2, 10)
                state['r1'] = r1
                state['r2'] = r2
                state['r'] = r
                if self.adam:
                    # Plain Adam for comparison: ignore the trust ratio.
                    r = 1

                p.data.add_(-step_size * r, adam_step)

        return loss
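For reference, a minimal usage sketch of the Lamb class above, standing alone outside fairseq. The toy model, data, and hyperparameter values are illustrative only, and it assumes a PyTorch version that still accepts the deprecated add_/addcmul_ call signatures used in the implementation.

import torch
import torch.nn as nn

# Toy regression model and data (illustrative only).
model = nn.Linear(16, 1)
x = torch.randn(64, 16)
y = torch.randn(64, 1)

optimizer = Lamb(model.parameters(), lr=1e-3, betas=(0.9, 0.999),
                 eps=1e-8, weight_decay=0.01)

for _ in range(10):
    optimizer.zero_grad()
    loss = ((model(x) - y) ** 2).mean()
    loss.backward()
    optimizer.step()  # applies the LAMB update, one trust ratio per tensor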