Commit e04bab59 authored by lukovnikov

fix for negative learning rate with warmup_linear in BertAdam (happens when t_total is specified incorrectly)
+ copied BERT optimization warmup functions to OpenAI optimization file + added comments
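The old `warmup_linear` returns `1.0 - x` after the warmup phase, where `x` is the training progress `global_step / t_total`. If `t_total` is set smaller than the number of steps actually taken, `x` grows past 1 and the schedule value, and with it the effective learning rate, turns negative, so updates move in the wrong direction. A minimal sketch of the before/after behaviour (the function names below are illustrative, not part of the library):

# x is the training progress, i.e. global_step / t_total
def warmup_linear_old(x, warmup=0.002):
    if x < warmup:
        return x / warmup
    return 1.0 - x                 # goes negative once x > 1 (t_total underestimated)

def warmup_linear_fixed(x, warmup=0.002):
    if x < warmup:
        return x / warmup
    return max(1.0 - x, 0)         # clamped at zero instead

print(warmup_linear_old(1.2))      # -0.2 -> negative learning-rate multiplier
print(warmup_linear_fixed(1.2))    #  0   -> learning rate bottoms out at zero, never flips sign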
parent 2152bfea
@@ -26,14 +26,18 @@ def warmup_cosine(x, warmup=0.002):
     return 0.5 * (1.0 + torch.cos(math.pi * x))
 
 def warmup_constant(x, warmup=0.002):
+    """ Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps.
+        Learning rate is 1. afterwards. """
     if x < warmup:
         return x/warmup
     return 1.0
 
 def warmup_linear(x, warmup=0.002):
+    """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step.
+        After `t_total`-th training step, learning rate is zero. """
     if x < warmup:
         return x/warmup
-    return 1.0 - x
+    return max(1.0 - x, 0)
 
 SCHEDULES = {
     'warmup_cosine':warmup_cosine,
...
@@ -21,16 +21,23 @@ from torch.optim.optimizer import required
 from torch.nn.utils import clip_grad_norm_
 
 def warmup_cosine(x, warmup=0.002):
-    s = 1 if x <= warmup else 0
-    return s*(x/warmup) + (1-s)*(0.5 * (1 + torch.cos(math.pi * x)))
+    if x < warmup:
+        return x/warmup
+    return 0.5 * (1.0 + torch.cos(math.pi * x))
 
 def warmup_constant(x, warmup=0.002):
-    s = 1 if x <= warmup else 0
-    return s*(x/warmup) + (1-s)*1
+    """ Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps.
+        Learning rate is 1. afterwards. """
+    if x < warmup:
+        return x/warmup
+    return 1.0
 
 def warmup_linear(x, warmup=0.002):
-    s = 1 if x <= warmup else 0
-    return (s*(x/warmup) + (1-s))*(1-x)
+    """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step.
+        After `t_total`-th training step, learning rate is zero. """
+    if x < warmup:
+        return x/warmup
+    return max(1.0 - x, 0)
 
 SCHEDULES = {
     'warmup_cosine':warmup_cosine,
...
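For context, both optimizers scale the base learning rate by the schedule value at every step, roughly as sketched below (a simplified, assumed view of how the schedule is consumed, not a verbatim excerpt from BertAdam/OpenAIAdam); a mis-specified `t_total` is what pushes the progress ratio past 1:

# Assumed, simplified view of how the schedule feeds the per-step learning rate.
num_steps_taken = 1200
t_total = 1000                      # mis-specified: training actually runs 1200 steps
base_lr = 5e-5

progress = num_steps_taken / t_total              # 1.2, because t_total was underestimated
lr_this_step = base_lr * warmup_linear(progress, warmup=0.1)
# old schedule: 5e-5 * -0.2 -> negative learning rate; fixed schedule: 5e-5 * 0 -> 0.0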