Commit 070f3b21 authored by lukovnikov's avatar lukovnikov
Browse files

added warning

parent 46ef6460
@@ -154,15 +154,15 @@ class BertAdam(Optimizer):
         if group['t_total'] != -1:
             schedule_fct = SCHEDULES[group['schedule']]
-            # warning for exceeding t_total (only active with warmup_linear
             progress = state['step']/group['t_total']
+            lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'])
+            # warning for exceeding t_total (only active with warmup_linear
             if progress > 1. and progress > self._warned_for_t_total_at_progress:
                 logger.warning(
-                    "Training beyond specified 't_total' steps. Learning rate set to zero. "
-                    "Please set 't_total' of {} correctly.".format(self.__class__.__name__))
+                    "Training beyond specified 't_total' steps. Learning rate set to {}. "
+                    "Please set 't_total' of {} correctly.".format(lr_scheduled, self.__class__.__name__))
                 self._warned_for_t_total_at_progress = progress
             # end warning
-            lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'])
         else:
             lr_scheduled = group['lr']
@@ -137,15 +137,15 @@ class OpenAIAdam(Optimizer):
         if group['t_total'] != -1:
             schedule_fct = SCHEDULES[group['schedule']]
-            # warning for exceeding t_total (only active with warmup_linear
             progress = state['step']/group['t_total']
+            lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'])
+            # warning for exceeding t_total (only active with warmup_linear
            if progress > 1. and progress > self._warned_for_t_total_at_progress:
                 logger.warning(
-                    "Training beyond specified 't_total' steps. Learning rate set to zero. "
-                    "Please set 't_total' of {} correctly.".format(self.__class__.__name__))
+                    "Training beyond specified 't_total' steps. Learning rate set to {}. "
+                    "Please set 't_total' of {} correctly.".format(lr_scheduled, self.__class__.__name__))
                 self._warned_for_t_total_at_progress = progress
             # end warning
         else:
             lr_scheduled = group['lr']
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment