"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "ef0e9d806c51059b07b98cb0279a20d3ba3cbc1d"
Commit 90a41dbe authored by lukovnikov

BertAdam schedule objects

parent 88874f6c
pytorch_pretrained_bert/__init__.py
@@ -18,7 +18,7 @@ from .modeling_gpt2 import (GPT2Config, GPT2Model,
                              GPT2LMHeadModel, GPT2DoubleHeadsModel,
                              load_tf_weights_in_gpt2)
-from .optimization import BertAdam
+from .optimization import *
 from .optimization_openai import OpenAIAdam
 from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE, cached_path
pytorch_pretrained_bert/optimization.py
@@ -24,6 +24,9 @@ import logging
 logger = logging.getLogger(__name__)
+
+__all__ = ["LRSchedule", "WarmupLinearSchedule", "WarmupConstantSchedule", "WarmupCosineSchedule", "BertAdam"]
+
 class LRSchedule(object):
     warn_t_total = False
     def __init__(self, warmup=0.002, t_total=-1, **kw):
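Taken together, these first two hunks re-export the new schedule classes at the package root: the wildcard import in __init__.py picks up exactly the names that optimization.py now declares in __all__. A hedged sketch of what this enables for users (the package name pytorch_pretrained_bert is an assumption inferred from the relative imports above, not shown in this diff):

    # Schedule classes are now importable next to BertAdam itself;
    # package name assumed, not shown in this diff.
    from pytorch_pretrained_bert import BertAdam, WarmupLinearSchedule, WarmupCosineSchedule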
@@ -83,32 +86,7 @@ class WarmupLinearSchedule(LRSchedule):
         if progress < self.warmup:
             return progress / self.warmup
         return max((progress - 1.) / (self.warmup - 1.), 0)
-#
-#
-# def warmup_cosine(x, warmup=0.002):
-#     if x < warmup:
-#         return x/warmup
-#     return 0.5 * (1.0 + torch.cos(math.pi * x))
-#
-# def warmup_constant(x, warmup=0.002):
-#     """ Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps.
-#         Learning rate is 1. afterwards. """
-#     if x < warmup:
-#         return x/warmup
-#     return 1.0
-#
-# def warmup_linear(x, warmup=0.002):
-#     """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step.
-#         After `t_total`-th training step, learning rate is zero. """
-#     if x < warmup:
-#         return x/warmup
-#     return max((x-1.)/(warmup-1.), 0)
-#
-# SCHEDULES = {
-#     'warmup_cosine': warmup_cosine,
-#     'warmup_constant': warmup_constant,
-#     'warmup_linear': warmup_linear,
-# }
 SCHEDULES = {
     None: LRSchedule,
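The deleted comment block was the old function-based implementation of the three schedules; the same math now lives in the schedule classes, as the surviving WarmupLinearSchedule body above shows. A quick numerical check of that triangular shape (assuming get_lr(step) normalizes the step by t_total before applying the formula, which this diff does not show):

    # Illustrative values: warmup=0.1 puts the peak 10% into training.
    sched = WarmupLinearSchedule(warmup=0.1, t_total=1000)
    print(sched.get_lr(50))    # 0.5 - halfway up the linear ramp
    print(sched.get_lr(100))   # 1.0 - peak at the warmup boundary
    print(sched.get_lr(550))   # 0.5 - halfway down the linear decay
    print(sched.get_lr(1000))  # 0.0 - fully decayed at t_total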
@@ -126,7 +104,9 @@ class BertAdam(Optimizer):
         warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
         t_total: total number of training steps for the learning
             rate schedule, -1 means constant learning rate. Default: -1
-        schedule: schedule to use for the warmup (see above). Default: 'warmup_linear'
+        schedule: schedule to use for the warmup (see above).
+            Can be 'warmup_linear', 'warmup_constant', 'warmup_cosine', or a LRSchedule object.
+            Default: 'warmup_linear'
         b1: Adams b1. Default: 0.9
         b2: Adams b2. Default: 0.999
         e: Adams epsilon. Default: 1e-6
@@ -147,9 +127,13 @@ class BertAdam(Optimizer):
         if not e >= 0.0:
             raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
         # initialize schedule object
-        schedule_type = SCHEDULES[schedule]
-        sched = schedule_type(warmup=warmup, t_total=t_total)
-        defaults = dict(lr=lr, schedule=sched,
+        if not isinstance(schedule, LRSchedule):
+            schedule_type = SCHEDULES[schedule]
+            schedule = schedule_type(warmup=warmup, t_total=t_total)
+        else:
+            if warmup != -1 or t_total != -1:
+                logger.warning("Non-default warmup and t_total are ineffective when LRSchedule object is provided.")
+        defaults = dict(lr=lr, schedule=schedule,
                         b1=b1, b2=b2, e=e, weight_decay=weight_decay,
                         max_grad_norm=max_grad_norm)
         super(BertAdam, self).__init__(params, defaults)
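This hunk is the heart of the commit: schedule may now be either a name that is looked up in SCHEDULES and instantiated from warmup/t_total, or a ready-made LRSchedule object, in which case the warmup/t_total arguments passed to the optimizer are ignored (with a warning if they were set). A usage sketch, assuming model is an existing torch.nn.Module and that the import path matches this package:

    from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule

    # Old-style string form: BertAdam builds the schedule object internally.
    optimizer = BertAdam(model.parameters(), lr=5e-5,
                         schedule='warmup_linear', warmup=0.1, t_total=1000)

    # New object form: construct the schedule yourself and leave warmup/t_total
    # at their -1 defaults, since BertAdam would ignore them (and warn) anyway.
    sched = WarmupLinearSchedule(warmup=0.1, t_total=1000)
    optimizer = BertAdam(model.parameters(), lr=5e-5, schedule=sched)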
@@ -163,7 +147,7 @@ class BertAdam(Optimizer):
                 return [0]
             lr_scheduled = group['lr']
-            lr_scheduled *= group['schedule'](state['step'])
+            lr_scheduled *= group['schedule'].get_lr(state['step'])
             lr.append(lr_scheduled)
         return lr
@@ -221,7 +205,7 @@ class BertAdam(Optimizer):
                     update += group['weight_decay'] * p.data
                 lr_scheduled = group['lr']
-                lr_scheduled *= group['schedule'](state['step'])
+                lr_scheduled *= group['schedule'].get_lr(state['step'])
                 update_with_lr = lr_scheduled * update
                 p.data.add_(-update_with_lr)
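Both call sites now go through the schedule object's get_lr() method rather than calling the schedule as a bare function. Because BertAdam.get_lr() (the first of the two hunks) returns the effective rate for each parameter, the scheduled value can be inspected during training; a minimal sketch, with the training-loop names (train_dataloader, the loss computation) assumed rather than taken from this commit:

    for step, batch in enumerate(train_dataloader):
        loss = model(batch)          # assumed: forward pass returning a loss
        loss.backward()
        optimizer.step()             # applies lr * schedule.get_lr(step) internally
        optimizer.zero_grad()
        if step % 100 == 0:
            # optimizer.get_lr() is the method shown above; returns a list of rates
            print("step", step, "effective lr:", optimizer.get_lr()[0])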