Unverified commit 7b9e5a54, authored by Thomas Wolf and committed by GitHub

Merge pull request #327 from lukovnikov/master

Issue #324: warmup_linear fixes
parents 4784b04f 35410da7
@@ -19,6 +19,9 @@ import torch
 from torch.optim import Optimizer
 from torch.optim.optimizer import required
 from torch.nn.utils import clip_grad_norm_
+import logging
+
+logger = logging.getLogger(__name__)
 
 def warmup_cosine(x, warmup=0.002):
     if x < warmup:
@@ -26,19 +29,23 @@ def warmup_cosine(x, warmup=0.002):
     return 0.5 * (1.0 + torch.cos(math.pi * x))
 
 def warmup_constant(x, warmup=0.002):
+    """ Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps.
+        Learning rate is 1. afterwards. """
     if x < warmup:
         return x/warmup
     return 1.0
 
 def warmup_linear(x, warmup=0.002):
+    """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step.
+        After `t_total`-th training step, learning rate is zero. """
     if x < warmup:
         return x/warmup
-    return 1.0 - x
+    return max((x-1.)/(warmup-1.), 0)
 
 SCHEDULES = {
-    'warmup_cosine':warmup_cosine,
-    'warmup_constant':warmup_constant,
-    'warmup_linear':warmup_linear,
+    'warmup_cosine': warmup_cosine,
+    'warmup_constant': warmup_constant,
+    'warmup_linear': warmup_linear,
 }
@@ -102,6 +109,8 @@ class BertAdam(Optimizer):
         if closure is not None:
             loss = closure()
 
+        warned_for_t_total = False
+
         for group in self.param_groups:
             for p in group['params']:
                 if p.grad is None:
@@ -145,7 +154,15 @@ class BertAdam(Optimizer):
                 if group['t_total'] != -1:
                     schedule_fct = SCHEDULES[group['schedule']]
-                    lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
+                    progress = state['step']/group['t_total']
+                    lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'])
+                    # warning for exceeding t_total (only active with warmup_linear)
+                    if group['schedule'] == "warmup_linear" and progress > 1. and not warned_for_t_total:
+                        logger.warning(
+                            "Training beyond specified 't_total' steps with schedule '{}'. Learning rate set to {}. "
+                            "Please set 't_total' of {} correctly.".format(group['schedule'], lr_scheduled, self.__class__.__name__))
+                        warned_for_t_total = True
+                    # end warning
                 else:
                     lr_scheduled = group['lr']
...
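To see what the warmup_linear change does in practice: only the region past the warmup phase changes. Instead of decaying from 1.0 at x = 0 (and going negative once x exceeds 1.0), the new form peaks at x = warmup, reaches zero exactly at x = 1.0, and stays at zero for any extra steps beyond t_total. Below is a minimal sketch, not part of the commit; the *_old / *_new function names are illustrative only.

# Minimal sketch (not part of this commit): compares the old and new
# warmup_linear once training progress x = step / t_total passes 1.0.
# The *_old / *_new names are illustrative only.

def warmup_linear_old(x, warmup=0.002):
    if x < warmup:
        return x / warmup
    return 1.0 - x                          # goes negative once x > 1.0

def warmup_linear_new(x, warmup=0.002):
    if x < warmup:
        return x / warmup
    # linear decay from 1.0 at x == warmup down to 0.0 at x == 1.0,
    # clamped at 0.0 for any steps beyond t_total
    return max((x - 1.) / (warmup - 1.), 0)

for x in (0.001, 0.002, 0.5, 1.0, 1.2):
    print(x, warmup_linear_old(x), warmup_linear_new(x))
# At x = 1.2 the old schedule returns roughly -0.2, i.e. a negative
# learning rate multiplier; the new schedule returns 0.0 instead.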
@@ -19,18 +19,28 @@ import torch
 from torch.optim import Optimizer
 from torch.optim.optimizer import required
 from torch.nn.utils import clip_grad_norm_
+import logging
+
+logger = logging.getLogger(__name__)
 
 def warmup_cosine(x, warmup=0.002):
-    s = 1 if x <= warmup else 0
-    return s*(x/warmup) + (1-s)*(0.5 * (1 + torch.cos(math.pi * x)))
+    if x < warmup:
+        return x/warmup
+    return 0.5 * (1.0 + torch.cos(math.pi * x))
 
 def warmup_constant(x, warmup=0.002):
-    s = 1 if x <= warmup else 0
-    return s*(x/warmup) + (1-s)*1
+    """ Linearly increases learning rate over `warmup`*`t_total` (as provided to OpenAIAdam) training steps.
+        Learning rate is 1. afterwards. """
+    if x < warmup:
+        return x/warmup
+    return 1.0
 
 def warmup_linear(x, warmup=0.002):
-    s = 1 if x <= warmup else 0
-    return (s*(x/warmup) + (1-s))*(1-x)
+    """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to OpenAIAdam) training step.
+        After `t_total`-th training step, learning rate is zero. """
+    if x < warmup:
+        return x/warmup
+    return max((x-1.)/(warmup-1.), 0)
 
 SCHEDULES = {
     'warmup_cosine':warmup_cosine,
@@ -88,6 +98,8 @@ class OpenAIAdam(Optimizer):
         if closure is not None:
             loss = closure()
 
+        warned_for_t_total = False
+
         for group in self.param_groups:
             for p in group['params']:
                 if p.grad is None:
@@ -125,7 +137,15 @@ class OpenAIAdam(Optimizer):
                 if group['t_total'] != -1:
                     schedule_fct = SCHEDULES[group['schedule']]
-                    lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
+                    progress = state['step']/group['t_total']
+                    lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'])
+                    # warning for exceeding t_total (only active with warmup_linear)
+                    if group['schedule'] == "warmup_linear" and progress > 1. and not warned_for_t_total:
+                        logger.warning(
+                            "Training beyond specified 't_total' steps with schedule '{}'. Learning rate set to {}. "
+                            "Please set 't_total' of {} correctly.".format(group['schedule'], lr_scheduled, self.__class__.__name__))
+                        warned_for_t_total = True
+                    # end warning
                 else:
                     lr_scheduled = group['lr']
...
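The new warning is emitted through the module-level logger, so it only shows up if logging is configured in the training script. A minimal usage sketch follows; it is not part of the commit, the toy model and num_train_steps value are illustrative, and the top-level BertAdam import path and constructor keywords are assumptions based on the lr / warmup / t_total / schedule param-group keys seen in the diff.

# Minimal usage sketch (not part of this commit). Assumes the package's
# top-level BertAdam export; lr / warmup / t_total / schedule mirror the
# param-group keys used in the diff above.
import logging
logging.basicConfig(level=logging.WARNING)    # make logger.warning output visible

import torch
from torch import nn
from pytorch_pretrained_bert import BertAdam  # assumed import path

model = nn.Linear(4, 2)                       # toy model for illustration
num_train_steps = 10                          # deliberately small t_total
optimizer = BertAdam(model.parameters(), lr=5e-5, warmup=0.1,
                     t_total=num_train_steps, schedule='warmup_linear')

for _ in range(15):                           # 5 optimization steps beyond t_total
    loss = model(torch.randn(3, 4)).sum()
    loss.backward()
    optimizer.step()                          # past t_total: lr multiplier is clamped to 0 and the
    optimizer.zero_grad()                     # "Training beyond specified 't_total' steps" warning is logged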