Commit 4d79e0d3 authored by lukovnikov's avatar lukovnikov
Browse files

added warning

parent 66a84b63
...@@ -82,8 +82,6 @@ class BertAdam(Optimizer): ...@@ -82,8 +82,6 @@ class BertAdam(Optimizer):
b1=b1, b2=b2, e=e, weight_decay=weight_decay, b1=b1, b2=b2, e=e, weight_decay=weight_decay,
max_grad_norm=max_grad_norm) max_grad_norm=max_grad_norm)
super(BertAdam, self).__init__(params, defaults) super(BertAdam, self).__init__(params, defaults)
# warning for t_total exceeded
self._warned_for_t_total_at_progress = -1 if schedule == "warmup_linear" else float("inf") # warning is not active with other schedules (since it doesn't break them)
def get_lr(self): def get_lr(self):
lr = [] lr = []
...@@ -111,6 +109,8 @@ class BertAdam(Optimizer): ...@@ -111,6 +109,8 @@ class BertAdam(Optimizer):
if closure is not None: if closure is not None:
loss = closure() loss = closure()
warned_for_t_total = False
for group in self.param_groups: for group in self.param_groups:
for p in group['params']: for p in group['params']:
if p.grad is None: if p.grad is None:
...@@ -157,11 +157,11 @@ class BertAdam(Optimizer): ...@@ -157,11 +157,11 @@ class BertAdam(Optimizer):
progress = state['step']/group['t_total'] progress = state['step']/group['t_total']
lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup']) lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'])
# warning for exceeding t_total (only active with warmup_linear # warning for exceeding t_total (only active with warmup_linear
if progress > 1. and progress > self._warned_for_t_total_at_progress: if group['schedule'] == "warmup_linear" and progress > 1. and not warned_for_t_total:
logger.warning( logger.warning(
"Training beyond specified 't_total' steps. Learning rate set to {}. " "Training beyond specified 't_total' steps. Learning rate set to {}. "
"Please set 't_total' of {} correctly.".format(lr_scheduled, self.__class__.__name__)) "Please set 't_total' of {} correctly.".format(lr_scheduled, self.__class__.__name__))
self._warned_for_t_total_at_progress = progress warned_for_t_total = True
# end warning # end warning
else: else:
lr_scheduled = group['lr'] lr_scheduled = group['lr']
......
...@@ -71,8 +71,6 @@ class OpenAIAdam(Optimizer): ...@@ -71,8 +71,6 @@ class OpenAIAdam(Optimizer):
b1=b1, b2=b2, e=e, weight_decay=weight_decay, vector_l2=vector_l2, b1=b1, b2=b2, e=e, weight_decay=weight_decay, vector_l2=vector_l2,
max_grad_norm=max_grad_norm) max_grad_norm=max_grad_norm)
super(OpenAIAdam, self).__init__(params, defaults) super(OpenAIAdam, self).__init__(params, defaults)
# warning for t_total exceeded
self._warned_for_t_total_at_progress = -1 if schedule == "warmup_linear" else float("inf") # warning is not active with other schedules (since it doesn't break them)
def get_lr(self): def get_lr(self):
lr = [] lr = []
...@@ -100,6 +98,8 @@ class OpenAIAdam(Optimizer): ...@@ -100,6 +98,8 @@ class OpenAIAdam(Optimizer):
if closure is not None: if closure is not None:
loss = closure() loss = closure()
warned_for_t_total = False
for group in self.param_groups: for group in self.param_groups:
for p in group['params']: for p in group['params']:
if p.grad is None: if p.grad is None:
...@@ -140,11 +140,11 @@ class OpenAIAdam(Optimizer): ...@@ -140,11 +140,11 @@ class OpenAIAdam(Optimizer):
progress = state['step']/group['t_total'] progress = state['step']/group['t_total']
lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup']) lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'])
# warning for exceeding t_total (only active with warmup_linear # warning for exceeding t_total (only active with warmup_linear
if progress > 1. and progress > self._warned_for_t_total_at_progress: if group['schedule'] == "warmup_linear" and progress > 1. and not warned_for_t_total:
logger.warning( logger.warning(
"Training beyond specified 't_total' steps. Learning rate set to {}. " "Training beyond specified 't_total' steps. Learning rate set to {}. "
"Please set 't_total' of {} correctly.".format(lr_scheduled, self.__class__.__name__)) "Please set 't_total' of {} correctly.".format(lr_scheduled, self.__class__.__name__))
self._warned_for_t_total_at_progress = progress warned_for_t_total = True
# end warning # end warning
else: else:
lr_scheduled = group['lr'] lr_scheduled = group['lr']
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment