Unverified commit 7b9e5a54, authored by Thomas Wolf, committed by GitHub

Merge pull request #327 from lukovnikov/master

Issue #324: warmup linear fixes
parents 4784b04f 35410da7
@@ -19,6 +19,9 @@ import torch
 from torch.optim import Optimizer
 from torch.optim.optimizer import required
 from torch.nn.utils import clip_grad_norm_
+import logging
+
+logger = logging.getLogger(__name__)
 
 def warmup_cosine(x, warmup=0.002):
     if x < warmup:
@@ -26,19 +29,23 @@ def warmup_cosine(x, warmup=0.002):
     return 0.5 * (1.0 + torch.cos(math.pi * x))
 
 def warmup_constant(x, warmup=0.002):
+    """ Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps.
+        Learning rate is 1. afterwards. """
     if x < warmup:
         return x/warmup
     return 1.0
 
 def warmup_linear(x, warmup=0.002):
+    """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step.
+        After `t_total`-th training step, learning rate is zero. """
     if x < warmup:
         return x/warmup
-    return 1.0 - x
+    return max((x-1.)/(warmup-1.), 0)
 
 SCHEDULES = {
-    'warmup_cosine':warmup_cosine,
-    'warmup_constant':warmup_constant,
-    'warmup_linear':warmup_linear,
+    'warmup_cosine': warmup_cosine,
+    'warmup_constant': warmup_constant,
+    'warmup_linear': warmup_linear,
 }
 
 class BertAdam(Optimizer):
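To make the effect of the `warmup_linear` change concrete, here is a small standalone sketch (not part of the diff; `warmup=0.1` and the sample points are chosen for illustration, the functions themselves default to `warmup=0.002`). The old `1.0 - x` branch decays from `1 - warmup` rather than from 1 and goes negative once progress exceeds 1.0, i.e. once training runs past `t_total` (the behaviour reported in issue #324); the new `max((x-1.)/(warmup-1.), 0)` branch peaks at exactly 1.0 at the end of warmup, decays linearly to 0 at `t_total`, and stays clamped at 0 afterwards.

```python
# Standalone comparison of the schedule multiplier before and after this PR.
def old_warmup_linear(x, warmup=0.002):
    # behaviour before this PR
    if x < warmup:
        return x / warmup
    return 1.0 - x

def new_warmup_linear(x, warmup=0.002):
    # behaviour after this PR
    if x < warmup:
        return x / warmup
    return max((x - 1.) / (warmup - 1.), 0)

for progress in (0.05, 0.1, 0.5, 1.2):
    print(f"progress={progress:4.2f}  "
          f"old={old_warmup_linear(progress, warmup=0.1):+.3f}  "
          f"new={new_warmup_linear(progress, warmup=0.1):+.3f}")

# progress=0.05  old=+0.500  new=+0.500   (identical during warmup)
# progress=0.10  old=+0.900  new=+1.000   (old peak is 1 - warmup, new peak is exactly 1)
# progress=0.50  old=+0.500  new=+0.556
# progress=1.20  old=-0.200  new=+0.000   (old multiplier goes negative past t_total)
```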
@@ -102,6 +109,8 @@ class BertAdam(Optimizer):
         if closure is not None:
             loss = closure()
 
+        warned_for_t_total = False
+
         for group in self.param_groups:
             for p in group['params']:
                 if p.grad is None:
@@ -145,7 +154,15 @@
                 if group['t_total'] != -1:
                     schedule_fct = SCHEDULES[group['schedule']]
-                    lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
+                    progress = state['step']/group['t_total']
+                    lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'])
+                    # warning for exceeding t_total (only active with warmup_linear)
+                    if group['schedule'] == "warmup_linear" and progress > 1. and not warned_for_t_total:
+                        logger.warning(
+                            "Training beyond specified 't_total' steps with schedule '{}'. Learning rate set to {}. "
+                            "Please set 't_total' of {} correctly.".format(group['schedule'], lr_scheduled, self.__class__.__name__))
+                        warned_for_t_total = True
+                    # end warning
                 else:
                     lr_scheduled = group['lr']
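The second part of the change to `BertAdam.step()` adds a one-shot warning when training continues past `t_total`. The following is a minimal, self-contained sketch of that guard (the helper `scheduled_lr` and the sample numbers are assumptions for illustration, not code from the repository; only the branching mirrors the diff). With the fixed `warmup_linear`, the learning rate is silently held at 0 beyond `t_total`, so the warning is the only visible hint that `t_total` was set too low.

```python
import logging

logging.basicConfig(level=logging.WARNING)  # make logger.warning() visible
logger = logging.getLogger(__name__)

def warmup_linear(x, warmup=0.002):
    if x < warmup:
        return x / warmup
    return max((x - 1.) / (warmup - 1.), 0)

def scheduled_lr(step, lr, warmup, t_total, schedule="warmup_linear"):
    """Hypothetical helper reproducing the scheduling branch of step() for one group."""
    progress = step / t_total
    lr_scheduled = lr * warmup_linear(progress, warmup)
    if schedule == "warmup_linear" and progress > 1.:
        logger.warning("Training beyond specified 't_total' steps with schedule '%s'. "
                       "Learning rate set to %s.", schedule, lr_scheduled)
    return lr_scheduled

print(scheduled_lr(step=900,  lr=3e-5, warmup=0.1, t_total=1000))  # positive lr, no warning
print(scheduled_lr(step=1100, lr=3e-5, warmup=0.1, t_total=1000))  # 0.0, warning is emitted
```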
@@ -19,18 +19,28 @@ import torch
 from torch.optim import Optimizer
 from torch.optim.optimizer import required
 from torch.nn.utils import clip_grad_norm_
+import logging
+
+logger = logging.getLogger(__name__)
 
 def warmup_cosine(x, warmup=0.002):
-    s = 1 if x <= warmup else 0
-    return s*(x/warmup) + (1-s)*(0.5 * (1 + torch.cos(math.pi * x)))
+    if x < warmup:
+        return x/warmup
+    return 0.5 * (1.0 + torch.cos(math.pi * x))
 
 def warmup_constant(x, warmup=0.002):
-    s = 1 if x <= warmup else 0
-    return s*(x/warmup) + (1-s)*1
+    """ Linearly increases learning rate over `warmup`*`t_total` (as provided to OpenAIAdam) training steps.
+        Learning rate is 1. afterwards. """
+    if x < warmup:
+        return x/warmup
+    return 1.0
 
 def warmup_linear(x, warmup=0.002):
-    s = 1 if x <= warmup else 0
-    return (s*(x/warmup) + (1-s))*(1-x)
+    """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to OpenAIAdam) training step.
+        After `t_total`-th training step, learning rate is zero. """
+    if x < warmup:
+        return x/warmup
+    return max((x-1.)/(warmup-1.), 0)
 
 SCHEDULES = {
     'warmup_cosine':warmup_cosine,
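The OpenAIAdam file previously expressed the schedules as an arithmetic blend, `s*(warmup_branch) + (1-s)*(main_branch)`, instead of an explicit `if`. For `warmup_constant` the rewrite is purely cosmetic, but the old `warmup_linear` also damped the warmup ramp by `(1-x)` and, like the BertAdam version, went negative past `t_total`. A small sketch checking this (the sample points and `warmup=0.1` are illustrative; the values in the comments are approximate):

```python
def old_openai_warmup_constant(x, warmup=0.002):
    s = 1 if x <= warmup else 0
    return s * (x / warmup) + (1 - s) * 1

def old_openai_warmup_linear(x, warmup=0.002):
    s = 1 if x <= warmup else 0
    return (s * (x / warmup) + (1 - s)) * (1 - x)

def new_warmup_constant(x, warmup=0.002):
    if x < warmup:
        return x / warmup
    return 1.0

def new_warmup_linear(x, warmup=0.002):
    if x < warmup:
        return x / warmup
    return max((x - 1.) / (warmup - 1.), 0)

for x in (0.05, 0.5, 1.2):
    # warmup_constant: old blend form and new branch form agree at these points
    assert old_openai_warmup_constant(x, 0.1) == new_warmup_constant(x, 0.1)
    # warmup_linear: old form damps the ramp by (1 - x) and goes negative past t_total
    print(x, old_openai_warmup_linear(x, 0.1), new_warmup_linear(x, 0.1))

# 0.05 -> old ~0.475 (ramp damped by 1 - x)   new 0.5
# 0.5  -> old  0.5                            new ~0.556
# 1.2  -> old ~-0.2 (negative)                new 0 (clamped)
```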
@@ -88,6 +98,8 @@ class OpenAIAdam(Optimizer):
         if closure is not None:
             loss = closure()
 
+        warned_for_t_total = False
+
         for group in self.param_groups:
             for p in group['params']:
                 if p.grad is None:
@@ -125,7 +137,15 @@
                 if group['t_total'] != -1:
                     schedule_fct = SCHEDULES[group['schedule']]
-                    lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
+                    progress = state['step']/group['t_total']
+                    lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'])
+                    # warning for exceeding t_total (only active with warmup_linear)
+                    if group['schedule'] == "warmup_linear" and progress > 1. and not warned_for_t_total:
+                        logger.warning(
+                            "Training beyond specified 't_total' steps with schedule '{}'. Learning rate set to {}. "
+                            "Please set 't_total' of {} correctly.".format(group['schedule'], lr_scheduled, self.__class__.__name__))
+                        warned_for_t_total = True
+                    # end warning
                 else:
                     lr_scheduled = group['lr']
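For completeness, a hedged usage sketch of how these optimizers are typically constructed so that the schedule and the new warning apply (assuming the `pytorch_pretrained_bert` package layout at the time of this PR; `num_train_steps` and the `torch.nn.Linear` stand-in model are placeholders):

```python
import torch
from pytorch_pretrained_bert.optimization import BertAdam

model = torch.nn.Linear(10, 2)   # stand-in for a real BERT model
num_train_steps = 1000           # total number of optimizer.step() calls you plan to run

optimizer = BertAdam(model.parameters(),
                     lr=5e-5,
                     warmup=0.1,               # first 10% of t_total is linear warmup
                     t_total=num_train_steps)  # schedule defaults to 'warmup_linear'

# If the training loop calls optimizer.step() more than num_train_steps times,
# the guard added in this PR logs a warning (at most once per step() call) and the
# learning-rate multiplier is clamped at 0 instead of going negative (issue #324).
```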