Commit 60a37238 authored by lukovnikov
Browse files

added warning

parent da2d8ca2
...@@ -19,6 +19,9 @@ import torch ...@@ -19,6 +19,9 @@ import torch
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.optim.optimizer import required from torch.optim.optimizer import required
from torch.nn.utils import clip_grad_norm_ from torch.nn.utils import clip_grad_norm_
import logging
logger = logging.getLogger(__name__)
def warmup_cosine(x, warmup=0.002): def warmup_cosine(x, warmup=0.002):
if x < warmup: if x < warmup:
...@@ -37,6 +40,8 @@ def warmup_linear(x, warmup=0.002): ...@@ -37,6 +40,8 @@ def warmup_linear(x, warmup=0.002):
After `t_total`-th training step, learning rate is zero. """ After `t_total`-th training step, learning rate is zero. """
if x < warmup: if x < warmup:
return x/warmup return x/warmup
if x > 1:
logger.warning("Training beyond specified 't_total' steps. Learning rate set to zero. Please set 't_total' of BertAdam correctly.")
return max((x-1.)/(warmup-1.), 0) return max((x-1.)/(warmup-1.), 0)
SCHEDULES = { SCHEDULES = {
......
...@@ -19,6 +19,9 @@ import torch ...@@ -19,6 +19,9 @@ import torch
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.optim.optimizer import required from torch.optim.optimizer import required
from torch.nn.utils import clip_grad_norm_ from torch.nn.utils import clip_grad_norm_
import logging
logger = logging.getLogger(__name__)
def warmup_cosine(x, warmup=0.002): def warmup_cosine(x, warmup=0.002):
if x < warmup: if x < warmup:
...@@ -37,6 +40,8 @@ def warmup_linear(x, warmup=0.002): ...@@ -37,6 +40,8 @@ def warmup_linear(x, warmup=0.002):
After `t_total`-th training step, learning rate is zero. """ After `t_total`-th training step, learning rate is zero. """
if x < warmup: if x < warmup:
return x/warmup return x/warmup
if x > 1:
logger.warning("Training beyond specified 't_total' steps. Learning rate set to zero. Please set 't_total' of BertAdam correctly.")
return max((x-1.)/(warmup-1.), 0) return max((x-1.)/(warmup-1.), 0)
SCHEDULES = { SCHEDULES = {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment