"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "11792d7826854979bb532b6da09bc3796b09ea6a"
Commit 1758c8fc authored by lukovnikov

- updated docs for optimization

parent 725a5632
@@ -25,12 +25,18 @@ logger = logging.getLogger(__name__)
__all__ = ["LRSchedule", "WarmupLinearSchedule", "WarmupConstantSchedule", "WarmupCosineSchedule", "BertAdam",
"WarmupMultiCosineSchedule", "WarmupCosineWithRestartsSchedule"]
"WarmupCosineWithHardRestartsSchedule", "WarmupCosineWithWarmupRestartsSchedule", "SCHEDULES"]
class LRSchedule(object):
warn_t_total = False
""" Parent of all LRSchedules here. """
warn_t_total = False # is set to True for schedules where progressing beyond t_total steps doesn't make sense
def __init__(self, warmup=0.002, t_total=-1, **kw):
"""
:param warmup: what fraction of t_total steps will be used for linear warmup
:param t_total: how many training steps (updates) are planned
:param kw:
"""
super(LRSchedule, self).__init__(**kw)
self.warmup, self.t_total = warmup, t_total
if t_total <= 0:
@@ -40,6 +46,11 @@ class LRSchedule(object):
self.warned_for_t_total_at_progress = -1
def get_lr(self, step, nowarn=False):
"""
:param step: which of t_total steps we're on
:param nowarn: set to True to suppress warning regarding training beyond specified 't_total' steps
:return: learning rate multiplier for current update
"""
progress = step / self.t_total
ret = self.get_lr_(progress)
# warning for exceeding t_total (only active with warmup_linear
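For orientation, a minimal standalone sketch of the flow documented above: get_lr() turns an absolute step count into a progress fraction, delegates to get_lr_(), and warns when training runs past t_total. The helper name and arguments below are illustrative, not the class's own code.

import logging

logger = logging.getLogger(__name__)

def scheduled_multiplier(step, t_total, get_lr_, warn_t_total=True):
    # Fraction of the planned training that has been completed.
    progress = step / t_total
    multiplier = get_lr_(progress)
    # Warn once training exceeds the planned number of steps.
    if warn_t_total and progress > 1.0:
        logger.warning("Training beyond specified 't_total' steps (progress=%.3f)", progress)
    return multiplier

# e.g. scheduled_multiplier(500, 1000, lambda p: 1.0) gives 1.0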
@@ -51,14 +62,27 @@ class LRSchedule(object):
# end warning
return ret
def get_lr_(self, step):
def get_lr_(self, progress):
"""
:param progress: value between 0 and 1 (unless going beyond t_total steps) specifying training progress
:return: learning rate multiplier for current update
"""
return 1.
# raise NotImplemented("use subclass") -
class WarmupCosineSchedule(LRSchedule):
"""
Cosine learning rate schedule with linear warmup. Cosine after warmup is without restarts.
"""
warn_t_total = True
def __init__(self, warmup=0.002, t_total=-1, cycles=.5, **kw):
"""
:param warmup: see LRSchedule
:param t_total: see LRSchedule
:param cycles: number of cycles. Default: 0.5, corresponding to cosine decay from 1. at progress==warmup and 0 at progress==1.
:param kw:
"""
super(WarmupCosineSchedule, self).__init__(warmup=warmup, t_total=t_total, **kw)
self.cycles = cycles
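A rough standalone sketch of the curve this docstring describes: linear warmup to 1.0, then a cosine decay. Only the cosine return line appears in the next hunk; the warmup branch and the rescaling of progress after warmup are assumptions consistent with the docstring, and the function name is illustrative.

import math

def warmup_cosine_multiplier(progress, warmup=0.002, cycles=0.5):
    # Linear warmup over the first `warmup` fraction of training.
    if progress < warmup:
        return progress / warmup
    # Rescale so the cosine runs over the post-warmup part of training (assumed).
    progress = (progress - warmup) / (1 - warmup)
    # cycles=0.5 decays from 1. at the end of warmup to 0. at progress == 1.
    return 0.5 * (1.0 + math.cos(math.pi * cycles * 2 * progress))

# e.g. warmup_cosine_multiplier(0.001) gives 0.5 (halfway through warmup),
#      warmup_cosine_multiplier(1.0) gives 0.0 (end of training)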
@@ -73,10 +97,12 @@ class WarmupCosineSchedule(LRSchedule):
return 0.5 * (1. + math.cos(math.pi * self.cycles * 2 * progress))
class WarmupMultiCosineSchedule(WarmupCosineSchedule):
warn_t_total = True
class WarmupCosineWithHardRestartsSchedule(WarmupCosineSchedule):
"""
Cosine learning rate schedule with linear warmup and hard restarts (if cycles > 1).
"""
def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
super(WarmupMultiCosineSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw)
super(WarmupCosineWithHardRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw)
assert(cycles >= 1.)
def get_lr_(self, progress):
@@ -90,7 +116,16 @@ class WarmupMultiCosineSchedule(WarmupCosineSchedule):
return ret
class WarmupCosineWithRestartsSchedule(WarmupMultiCosineSchedule):
class WarmupCosineWithWarmupRestartsSchedule(WarmupCosineWithHardRestartsSchedule):
"""
Cosine learning rate schedule with linear warmups and linear warmup restarts.
The same warmup rate is used for warmup restarts as for initial warmup.
The total effective fraction of warmup steps over all cycles is warmup * cycles!
"""
def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
assert(warmup * cycles < 1.)
super(WarmupCosineWithWarmupRestartsSchedule, self).__init__(warmup=warmup*cycles, t_total=t_total, cycles=cycles, **kw)
def get_lr_(self, progress):
if self.t_total <= 0.:
return 1.
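A rough standalone sketch of the per-cycle behaviour described above: every cycle repeats a linear warmup followed by one cosine decay, and the per-cycle warmup fraction is warmup * cycles so the warmup rate in steps matches the no-restart case. The folding of global progress into a cycle is an assumption consistent with the docstring and the warmup * cycles passed to the parent, not code taken from this hunk.

import math

def warmup_restarts_multiplier(progress, warmup=0.002, cycles=1.0):
    cycle_warmup = warmup * cycles                 # warmup fraction within one cycle
    cycle_progress = (progress * cycles) % 1.0     # progress within the current cycle (assumed folding)
    if cycle_progress < cycle_warmup:
        return cycle_progress / cycle_warmup       # linear warmup restart
    cycle_progress = (cycle_progress - cycle_warmup) / (1 - cycle_warmup)
    return 0.5 * (1.0 + math.cos(math.pi * cycle_progress))  # one cosine decay per cycle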
@@ -104,7 +139,9 @@ class WarmupCosineWithRestartsSchedule(WarmupMultiCosineSchedule):
class WarmupConstantSchedule(LRSchedule):
warn_t_total = False
"""
Applies linear warmup. After warmup always returns 1..
"""
def get_lr_(self, progress):
if progress < self.warmup:
return progress / self.warmup
@@ -112,6 +149,9 @@ class WarmupConstantSchedule(LRSchedule):
class WarmupLinearSchedule(LRSchedule):
"""
Linear warmup. Linear decay after warmup.
"""
warn_t_total = True
def get_lr_(self, progress):
if progress < self.warmup:
@@ -145,8 +185,7 @@ class BertAdam(Optimizer):
max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
"""
def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear',
b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, init_weight_decay=0.,
max_grad_norm=1.0):
b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, max_grad_norm=1.0, **kwargs):
if lr is not required and lr < 0.0:
raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
if not isinstance(schedule, LRSchedule) and schedule not in SCHEDULES:
@@ -163,9 +202,10 @@ class BertAdam(Optimizer):
schedule = schedule_type(warmup=warmup, t_total=t_total)
else:
if warmup != -1 or t_total != -1:
logger.warning("Non-default warmup and t_total are ineffective when LRSchedule object is provided.")
logger.warning("Non-default warmup and t_total are ineffective when LRSchedule object is provided. "
"Please specify custom warmup and t_total in LRSchedule object.")
defaults = dict(lr=lr, schedule=schedule,
b1=b1, b2=b2, e=e, weight_decay=weight_decay, init_weight_decay=init_weight_decay,
b1=b1, b2=b2, e=e, weight_decay=weight_decay,
max_grad_norm=max_grad_norm)
super(BertAdam, self).__init__(params, defaults)
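Usage sketch for the updated constructor: the schedule can be given by name or as a configured LRSchedule object. The import path, model, and hyperparameters below are assumptions for illustration; the diff itself does not show file paths.

import torch
from pytorch_pretrained_bert.optimization import BertAdam, WarmupCosineWithWarmupRestartsSchedule  # assumed path

model = torch.nn.Linear(10, 2)  # stand-in model

# Option 1: schedule by name; warmup and t_total are passed to the optimizer.
optimizer = BertAdam(model.parameters(), lr=5e-5,
                     warmup=0.1, t_total=1000, schedule='warmup_linear')

# Option 2: a configured LRSchedule object; warmup and t_total then live on the schedule,
# and passing them to the optimizer as well only triggers the warning above.
schedule = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=1000, cycles=3)
optimizer = BertAdam(model.parameters(), lr=5e-5, schedule=schedule)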
@@ -176,10 +216,8 @@ class BertAdam(Optimizer):
state = self.state[p]
if len(state) == 0:
return [0]
lr_scheduled = group['lr']
lr_scheduled *= group['schedule'].get_lr(state['step'])
lr.append(lr_scheduled)
return lr
@@ -235,8 +273,6 @@ class BertAdam(Optimizer):
if group['weight_decay'] > 0.0:
update += group['weight_decay'] * p.data
# TODO: init weight decay
lr_scheduled = group['lr']
lr_scheduled *= group['schedule'].get_lr(state['step'])
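Condensed into a standalone helper, the update performed around these lines looks roughly like the sketch below. The helper and argument names are illustrative; exp_avg and exp_avg_sq stand for the Adam moment tensors kept in the optimizer state, and gradient clipping is omitted.

def bertadam_update(p, exp_avg, exp_avg_sq, schedule, step,
                    lr=5e-5, eps=1e-6, weight_decay=0.01):
    # p, exp_avg, exp_avg_sq are torch tensors.
    # Adam direction without bias correction, as in BertAdam.
    update = exp_avg / (exp_avg_sq.sqrt() + eps)
    if weight_decay > 0.0:
        update = update + weight_decay * p.data   # decoupled weight decay ("weight decay fix")
    lr_scheduled = lr * schedule.get_lr(step)     # base lr times schedule multiplier
    p.data.add_(-lr_scheduled * update)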
@@ -20,35 +20,10 @@ from torch.optim import Optimizer
from torch.optim.optimizer import required
from torch.nn.utils import clip_grad_norm_
import logging
from .optimization import *
logger = logging.getLogger(__name__)
def warmup_cosine(x, warmup=0.002):
if x < warmup:
return x/warmup
x_ = (x - warmup) / (1 - warmup) # progress after warmup
return 0.5 * (1. + math.cos(math.pi * x_))
def warmup_constant(x, warmup=0.002):
""" Linearly increases learning rate over `warmup`*`t_total` (as provided to OpenAIAdam) training steps.
Learning rate is 1. afterwards. """
if x < warmup:
return x/warmup
return 1.0
def warmup_linear(x, warmup=0.002):
""" Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to OpenAIAdam) training step.
After `t_total`-th training step, learning rate is zero. """
if x < warmup:
return x/warmup
return max((x-1.)/(warmup-1.), 0)
SCHEDULES = {
'warmup_cosine':warmup_cosine,
'warmup_constant':warmup_constant,
'warmup_linear':warmup_linear,
}
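The removed closed-form schedules map a progress fraction x in [0, 1] straight to a multiplier; a few spot values with the default warmup=0.002, for comparison with the class-based replacements now imported via "from .optimization import *":

# Spot-checks of the removed functions (not part of the original file).
assert warmup_linear(0.001) == 0.5               # halfway through warmup
assert warmup_linear(0.002) == 1.0               # peak, warmup just finished
assert abs(warmup_linear(0.5) - 0.501) < 1e-3    # roughly half the peak on the decay
assert warmup_linear(1.0) == 0.0                 # end of training
assert warmup_constant(0.5) == 1.0               # flat after warmup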
class OpenAIAdam(Optimizer):
"""Implements Open AI version of Adam algorithm with weight decay fix.
@@ -58,17 +33,23 @@ class OpenAIAdam(Optimizer):
vector_l2=False, max_grad_norm=-1, **kwargs):
if lr is not required and lr < 0.0:
raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
if schedule not in SCHEDULES:
if not isinstance(schedule, LRSchedule) and schedule not in SCHEDULES:
raise ValueError("Invalid schedule parameter: {}".format(schedule))
if not 0.0 <= warmup < 1.0 and not warmup == -1:
raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
if not 0.0 <= b1 < 1.0:
raise ValueError("Invalid b1 parameter: {}".format(b1))
raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
if not 0.0 <= b2 < 1.0:
raise ValueError("Invalid b2 parameter: {}".format(b2))
raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
if not e >= 0.0:
raise ValueError("Invalid epsilon value: {}".format(e))
defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
# initialize schedule object
if not isinstance(schedule, LRSchedule):
schedule_type = SCHEDULES[schedule]
schedule = schedule_type(warmup=warmup, t_total=t_total)
else:
if warmup != -1 or t_total != -1:
logger.warning("Non-default warmup and t_total are ineffective when LRSchedule object is provided. "
"Please specify custom warmup and t_total in LRSchedule object.")
defaults = dict(lr=lr, schedule=schedule,
b1=b1, b2=b2, e=e, weight_decay=weight_decay, vector_l2=vector_l2,
max_grad_norm=max_grad_norm)
super(OpenAIAdam, self).__init__(params, defaults)
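OpenAIAdam now resolves the schedule the same way BertAdam does. A brief usage sketch; the import paths, model, and hyperparameters are assumptions for illustration and are not shown in this diff.

import torch
from pytorch_pretrained_bert.optimization import WarmupConstantSchedule   # assumed path
from pytorch_pretrained_bert.optimization_openai import OpenAIAdam        # assumed path

model = torch.nn.Linear(10, 2)  # stand-in model

# With an LRSchedule object, warmup and t_total are configured on the schedule itself;
# passing them to the optimizer as well only triggers the warning added above.
schedule = WarmupConstantSchedule(warmup=0.01, t_total=500)
optimizer = OpenAIAdam(model.parameters(), lr=6.25e-5, schedule=schedule)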
@@ -80,11 +61,8 @@ class OpenAIAdam(Optimizer):
state = self.state[p]
if len(state) == 0:
return [0]
if group['t_total'] != -1:
schedule_fct = SCHEDULES[group['schedule']]
lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
else:
lr_scheduled = group['lr']
lr_scheduled = group['lr']
lr_scheduled *= group['schedule'].get_lr(state['step'])
lr.append(lr_scheduled)
return lr
@@ -99,8 +77,6 @@ class OpenAIAdam(Optimizer):
if closure is not None:
loss = closure()
warned_for_t_total = False
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
@@ -136,19 +112,8 @@ class OpenAIAdam(Optimizer):
bias_correction1 = 1 - beta1 ** state['step']
bias_correction2 = 1 - beta2 ** state['step']
if group['t_total'] != -1:
schedule_fct = SCHEDULES[group['schedule']]
progress = state['step']/group['t_total']
lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'])
# warning for exceeding t_total (only active with warmup_linear
if group['schedule'] == "warmup_linear" and progress > 1. and not warned_for_t_total:
logger.warning(
"Training beyond specified 't_total' steps with schedule '{}'. Learning rate set to {}. "
"Please set 't_total' of {} correctly.".format(group['schedule'], lr_scheduled, self.__class__.__name__))
warned_for_t_total = True
# end warning
else:
lr_scheduled = group['lr']
lr_scheduled = group['lr']
lr_scheduled *= group['schedule'].get_lr(state['step'])
step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
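For contrast with BertAdam, which applies no bias correction, a condensed sketch of how the scheduled rate and Adam's bias correction combine into the step size computed here (an illustrative helper, not the optimizer's own code; assumes step >= 1):

import math

def openai_adam_step_size(lr, schedule, step, b1=0.9, b2=0.999):
    lr_scheduled = lr * schedule.get_lr(step)     # base lr times schedule multiplier
    bias_correction1 = 1 - b1 ** step
    bias_correction2 = 1 - b2 ** step
    return lr_scheduled * math.sqrt(bias_correction2) / bias_correction1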