Commit 1758c8fc authored by lukovnikov

- updated docs for optimization

parent 725a5632
@@ -25,12 +25,18 @@ logger = logging.getLogger(__name__)
 __all__ = ["LRSchedule", "WarmupLinearSchedule", "WarmupConstantSchedule", "WarmupCosineSchedule", "BertAdam",
-           "WarmupMultiCosineSchedule", "WarmupCosineWithRestartsSchedule"]
+           "WarmupCosineWithHardRestartsSchedule", "WarmupCosineWithWarmupRestartsSchedule", "SCHEDULES"]
 
 class LRSchedule(object):
-    warn_t_total = False
+    """ Parent of all LRSchedules here. """
+    warn_t_total = False    # is set to True for schedules where progressing beyond t_total steps doesn't make sense
     def __init__(self, warmup=0.002, t_total=-1, **kw):
+        """
+        :param warmup:  what fraction of t_total steps will be used for linear warmup
+        :param t_total: how many training steps (updates) are planned
+        :param kw:
+        """
         super(LRSchedule, self).__init__(**kw)
         self.warmup, self.t_total = warmup, t_total
         if t_total <= 0:
@@ -40,6 +46,11 @@ class LRSchedule(object):
         self.warned_for_t_total_at_progress = -1
 
     def get_lr(self, step, nowarn=False):
+        """
+        :param step:    which of t_total steps we're on
+        :param nowarn:  set to True to suppress warning regarding training beyond specified 't_total' steps
+        :return:        learning rate multiplier for current update
+        """
         progress = step / self.t_total
         ret = self.get_lr_(progress)
         # warning for exceeding t_total (only active with warmup_linear
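A quick usage sketch of the get_lr API documented above. The import path assumes the package's optimization module and one of the schedule classes listed in __all__; adjust it to your checkout:

    # rough usage sketch (not part of the commit): query the multiplier at a few steps
    from pytorch_pretrained_bert.optimization import WarmupLinearSchedule

    schedule = WarmupLinearSchedule(warmup=0.1, t_total=1000)   # 10% of the 1000 planned steps are warmup
    for step in (0, 50, 100, 500, 1000):
        # get_lr() maps the absolute step to progress = step / t_total and returns a multiplier in [0, 1]
        print(step, schedule.get_lr(step))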
@@ -51,14 +62,27 @@ class LRSchedule(object):
         # end warning
         return ret
 
-    def get_lr_(self, step):
+    def get_lr_(self, progress):
+        """
+        :param progress:    value between 0 and 1 (unless going beyond t_total steps) specifying training progress
+        :return:            learning rate multiplier for current update
+        """
         return 1.
         # raise NotImplemented("use subclass")
 
 
 class WarmupCosineSchedule(LRSchedule):
+    """
+    Cosine learning rate schedule with linear warmup. Cosine after warmup is without restarts.
+    """
     warn_t_total = True
     def __init__(self, warmup=0.002, t_total=-1, cycles=.5, **kw):
+        """
+        :param warmup:      see LRSchedule
+        :param t_total:     see LRSchedule
+        :param cycles:      number of cycles. Default: 0.5, corresponding to cosine decay from 1. at progress==warmup and 0 at progress==1.
+        :param kw:
+        """
         super(WarmupCosineSchedule, self).__init__(warmup=warmup, t_total=t_total, **kw)
         self.cycles = cycles
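For the default cycles=0.5 this sets up a linear warmup followed by a single half-cosine decay from 1. to 0. A standalone sketch of that shape, mirroring the formula visible in the next hunk and the rescaling used by the old warmup_cosine helper removed further down (the function name here is illustrative, not the library's):

    import math

    def warmup_cosine_multiplier(progress, warmup=0.002, cycles=0.5):
        # sketch only: warmup ramp, then cosine decay over the remaining progress
        if progress < warmup:
            return progress / warmup
        progress = (progress - warmup) / (1 - warmup)   # rescale post-warmup progress to [0, 1]
        return 0.5 * (1. + math.cos(math.pi * cycles * 2 * progress))   # 1. -> 0. when cycles == 0.5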
@@ -73,10 +97,12 @@ class WarmupCosineSchedule(LRSchedule):
             return 0.5 * (1. + math.cos(math.pi * self.cycles * 2 * progress))
 
-class WarmupMultiCosineSchedule(WarmupCosineSchedule):
-    warn_t_total = True
+class WarmupCosineWithHardRestartsSchedule(WarmupCosineSchedule):
+    """
+    Cosine learning rate schedule with linear warmup and hard restarts (if cycles > 1).
+    """
     def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
-        super(WarmupMultiCosineSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw)
+        super(WarmupCosineWithHardRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw)
         assert(cycles >= 1.)
 
     def get_lr_(self, progress):
@@ -90,7 +116,16 @@ class WarmupMultiCosineSchedule(WarmupCosineSchedule):
             return ret
 
-class WarmupCosineWithRestartsSchedule(WarmupMultiCosineSchedule):
+class WarmupCosineWithWarmupRestartsSchedule(WarmupCosineWithHardRestartsSchedule):
+    """
+    Cosine learning rate schedule with linear warmups and linear warmup restarts.
+    The same warmup rate is used for warmup restarts as for initial warmup.
+    The total effective fraction of warmup steps over all cycles is warmup * cycles!
+    """
+    def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
+        assert(warmup * cycles < 1.)
+        super(WarmupCosineWithWarmupRestartsSchedule, self).__init__(warmup=warmup*cycles, t_total=t_total, cycles=cycles, **kw)
+
     def get_lr_(self, progress):
         if self.t_total <= 0.:
             return 1.
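The method bodies of the restart schedules are cut off in this hunk, so the following is only a rough sketch of the behaviour the docstring describes, not the library's exact code: overall progress is folded into the current cycle, and each cycle re-runs a warmup ramp followed by a cosine decay (in this sketch, `warmup` means the fraction of each cycle spent warming up):

    import math

    def warmup_restarts_multiplier(progress, warmup=0.002, cycles=1.):
        # sketch only: fold overall progress into the current cycle ...
        cycle_progress = (progress * cycles) % 1.
        if cycle_progress < warmup:
            return cycle_progress / warmup              # ... and warm up again at the start of every cycle
        cycle_progress = (cycle_progress - warmup) / (1 - warmup)
        return 0.5 * (1. + math.cos(math.pi * cycle_progress))   # then decay to 0 before the next restart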
@@ -104,7 +139,9 @@ class WarmupCosineWithRestartsSchedule(WarmupMultiCosineSchedule):
 class WarmupConstantSchedule(LRSchedule):
-    warn_t_total = False
+    """
+    Applies linear warmup. After warmup always returns 1..
+    """
     def get_lr_(self, progress):
         if progress < self.warmup:
             return progress / self.warmup
@@ -112,6 +149,9 @@ class WarmupConstantSchedule(LRSchedule):
 class WarmupLinearSchedule(LRSchedule):
+    """
+    Linear warmup. Linear decay after warmup.
+    """
     warn_t_total = True
     def get_lr_(self, progress):
         if progress < self.warmup:
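WarmupLinearSchedule keeps the triangular shape of the old warmup_linear helper that is removed further down in this commit; as a standalone sketch (illustrative function name):

    def warmup_linear_multiplier(progress, warmup=0.002):
        # sketch only: ramp from 0 to 1 over the warmup fraction, then decay linearly to 0 at progress == 1
        if progress < warmup:
            return progress / warmup
        return max((progress - 1.) / (warmup - 1.), 0.)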
@@ -145,8 +185,7 @@ class BertAdam(Optimizer):
         max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
     """
     def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear',
-                 b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, init_weight_decay=0.,
-                 max_grad_norm=1.0):
+                 b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, max_grad_norm=1.0, **kwargs):
         if lr is not required and lr < 0.0:
             raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
         if not isinstance(schedule, LRSchedule) and schedule not in SCHEDULES:
@@ -163,9 +202,10 @@ class BertAdam(Optimizer):
             schedule = schedule_type(warmup=warmup, t_total=t_total)
         else:
             if warmup != -1 or t_total != -1:
-                logger.warning("Non-default warmup and t_total are ineffective when LRSchedule object is provided.")
+                logger.warning("Non-default warmup and t_total are ineffective when LRSchedule object is provided. "
+                               "Please specify custom warmup and t_total in LRSchedule object.")
         defaults = dict(lr=lr, schedule=schedule,
-                        b1=b1, b2=b2, e=e, weight_decay=weight_decay, init_weight_decay=init_weight_decay,
+                        b1=b1, b2=b2, e=e, weight_decay=weight_decay,
                         max_grad_norm=max_grad_norm)
         super(BertAdam, self).__init__(params, defaults)
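A sketch of the two ways of picking a schedule that this constructor now supports: by name (BertAdam builds the LRSchedule object from SCHEDULES) or by passing a ready-made LRSchedule object, in which case the warmup/t_total arguments are ignored and the warning above is logged. Import paths and the toy model are assumptions for illustration:

    import torch
    from pytorch_pretrained_bert.optimization import BertAdam, WarmupCosineWithHardRestartsSchedule

    model = torch.nn.Linear(768, 2)        # stand-in model, for illustration only

    # 1) select the schedule by name; BertAdam instantiates it with warmup and t_total
    optimizer = BertAdam(model.parameters(), lr=5e-5, warmup=0.1, t_total=10000,
                         schedule='warmup_linear')

    # 2) pass a schedule object directly; warmup/t_total kwargs of BertAdam are then ineffective
    restarts = WarmupCosineWithHardRestartsSchedule(warmup=0.1, t_total=10000, cycles=3.)
    optimizer = BertAdam(model.parameters(), lr=5e-5, schedule=restarts)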
@@ -176,10 +216,8 @@ class BertAdam(Optimizer):
                 state = self.state[p]
                 if len(state) == 0:
                     return [0]
                 lr_scheduled = group['lr']
                 lr_scheduled *= group['schedule'].get_lr(state['step'])
                 lr.append(lr_scheduled)
         return lr
@@ -235,8 +273,6 @@ class BertAdam(Optimizer):
                 if group['weight_decay'] > 0.0:
                     update += group['weight_decay'] * p.data
 
-                # TODO: init weight decay
-
                 lr_scheduled = group['lr']
                 lr_scheduled *= group['schedule'].get_lr(state['step'])
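The lines above are the "weight decay fix": the decay term is added to the Adam update itself and scaled by the scheduled learning rate, rather than being folded into the gradient. A condensed, self-contained sketch of that pattern (variable names are illustrative, not the optimizer's internals):

    import torch

    p, grad = torch.randn(4), torch.randn(4)
    exp_avg, exp_avg_sq = torch.zeros(4), torch.zeros(4)
    b1, b2, eps, weight_decay, lr_scheduled = 0.9, 0.999, 1e-6, 0.01, 5e-5

    exp_avg.mul_(b1).add_(grad, alpha=1 - b1)                 # first moment
    exp_avg_sq.mul_(b2).addcmul_(grad, grad, value=1 - b2)    # second moment
    update = exp_avg / (exp_avg_sq.sqrt() + eps)              # Adam direction
    update += weight_decay * p                                # decoupled decay, as in the hunk above
    p -= lr_scheduled * update                                # one scaled step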
@@ -20,35 +20,10 @@ from torch.optim import Optimizer
 from torch.optim.optimizer import required
 from torch.nn.utils import clip_grad_norm_
 import logging
+from .optimization import *
 
 logger = logging.getLogger(__name__)
 
-def warmup_cosine(x, warmup=0.002):
-    if x < warmup:
-        return x/warmup
-    x_ = (x - warmup) / (1 - warmup)  # progress after warmup
-    return 0.5 * (1. + math.cos(math.pi * x_))
-
-def warmup_constant(x, warmup=0.002):
-    """ Linearly increases learning rate over `warmup`*`t_total` (as provided to OpenAIAdam) training steps.
-        Learning rate is 1. afterwards. """
-    if x < warmup:
-        return x/warmup
-    return 1.0
-
-def warmup_linear(x, warmup=0.002):
-    """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to OpenAIAdam) training step.
-        After `t_total`-th training step, learning rate is zero. """
-    if x < warmup:
-        return x/warmup
-    return max((x-1.)/(warmup-1.), 0)
-
-SCHEDULES = {
-    'warmup_cosine':warmup_cosine,
-    'warmup_constant':warmup_constant,
-    'warmup_linear':warmup_linear,
-}
-
 class OpenAIAdam(Optimizer):
     """Implements Open AI version of Adam algorithm with weight decay fix.
@@ -58,17 +33,23 @@ class OpenAIAdam(Optimizer):
                  vector_l2=False, max_grad_norm=-1, **kwargs):
         if lr is not required and lr < 0.0:
             raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
-        if schedule not in SCHEDULES:
+        if not isinstance(schedule, LRSchedule) and schedule not in SCHEDULES:
             raise ValueError("Invalid schedule parameter: {}".format(schedule))
-        if not 0.0 <= warmup < 1.0 and not warmup == -1:
-            raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
         if not 0.0 <= b1 < 1.0:
-            raise ValueError("Invalid b1 parameter: {}".format(b1))
+            raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
         if not 0.0 <= b2 < 1.0:
-            raise ValueError("Invalid b2 parameter: {}".format(b2))
+            raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
         if not e >= 0.0:
-            raise ValueError("Invalid epsilon value: {}".format(e))
-        defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
+            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
+        # initialize schedule object
+        if not isinstance(schedule, LRSchedule):
+            schedule_type = SCHEDULES[schedule]
+            schedule = schedule_type(warmup=warmup, t_total=t_total)
+        else:
+            if warmup != -1 or t_total != -1:
+                logger.warning("Non-default warmup and t_total are ineffective when LRSchedule object is provided. "
+                               "Please specify custom warmup and t_total in LRSchedule object.")
+        defaults = dict(lr=lr, schedule=schedule,
                         b1=b1, b2=b2, e=e, weight_decay=weight_decay, vector_l2=vector_l2,
                         max_grad_norm=max_grad_norm)
         super(OpenAIAdam, self).__init__(params, defaults)
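OpenAIAdam now resolves its schedule the same way BertAdam does. A brief usage sketch with a schedule object (module paths assumed from the usual package layout; the toy model is for illustration only):

    import torch
    from pytorch_pretrained_bert.optimization import WarmupCosineWithWarmupRestartsSchedule
    from pytorch_pretrained_bert.optimization_openai import OpenAIAdam

    model = torch.nn.Linear(768, 2)        # stand-in model, for illustration only
    schedule = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=20000, cycles=4.)
    # warmup/t_total kwargs would be ignored here (and warned about), since a schedule object is given
    optimizer = OpenAIAdam(model.parameters(), lr=6.25e-5, schedule=schedule)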
@@ -80,11 +61,8 @@ class OpenAIAdam(Optimizer):
                 state = self.state[p]
                 if len(state) == 0:
                     return [0]
-                if group['t_total'] != -1:
-                    schedule_fct = SCHEDULES[group['schedule']]
-                    lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
-                else:
-                    lr_scheduled = group['lr']
+                lr_scheduled = group['lr']
+                lr_scheduled *= group['schedule'].get_lr(state['step'])
                 lr.append(lr_scheduled)
         return lr
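With get_lr delegating to the schedule object, the currently scheduled learning rate can be read back during training. A rough sketch continuing the setup from the previous example (random data, for illustration only):

    import torch

    for step in range(300):
        x = torch.randn(8, 768)
        loss = model(x).pow(2).mean()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if step % 100 == 0:
            # one value per parameter seen so far, already multiplied by the schedule
            print(step, optimizer.get_lr())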
@@ -99,8 +77,6 @@ class OpenAIAdam(Optimizer):
         if closure is not None:
             loss = closure()
 
-        warned_for_t_total = False
-
         for group in self.param_groups:
             for p in group['params']:
                 if p.grad is None:
@@ -136,19 +112,8 @@ class OpenAIAdam(Optimizer):
                 bias_correction1 = 1 - beta1 ** state['step']
                 bias_correction2 = 1 - beta2 ** state['step']
-                if group['t_total'] != -1:
-                    schedule_fct = SCHEDULES[group['schedule']]
-                    progress = state['step']/group['t_total']
-                    lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'])
-                    # warning for exceeding t_total (only active with warmup_linear
-                    if group['schedule'] == "warmup_linear" and progress > 1. and not warned_for_t_total:
-                        logger.warning(
-                            "Training beyond specified 't_total' steps with schedule '{}'. Learning rate set to {}. "
-                            "Please set 't_total' of {} correctly.".format(group['schedule'], lr_scheduled, self.__class__.__name__))
-                        warned_for_t_total = True
-                    # end warning
-                else:
-                    lr_scheduled = group['lr']
+                lr_scheduled = group['lr']
+                lr_scheduled *= group['schedule'].get_lr(state['step'])
 
                 step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
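A worked sketch of the step-size line above, showing how the bias-correction factor sqrt(1 - b2**step) / (1 - b1**step) evolves with step for the defaults b1=0.9, b2=0.999 (the scheduled lr of 6.25e-5 is an assumed example value):

    import math

    lr_scheduled, b1, b2 = 6.25e-5, 0.9, 0.999
    for step in (1, 10, 1000):
        bias_correction1 = 1 - b1 ** step
        bias_correction2 = 1 - b2 ** step
        step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
        print(step, step_size)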