Commit c1468202 authored by Tayo Oguntebi's avatar Tayo Oguntebi Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 307918408
parent 22f10c60
...@@ -261,18 +261,15 @@ def get_callbacks(steps_per_epoch): ...@@ -261,18 +261,15 @@ def get_callbacks(steps_per_epoch):
return callbacks return callbacks
def build_stats(history, callbacks): def update_stats(history, stats, callbacks):
"""Normalizes and returns dictionary of stats. """Normalizes and updates dictionary of stats.
Args: Args:
history: Results of the training step. history: Results of the training step.
stats: Dict with pre-existing training stats.
callbacks: a list of callbacks which might include a time history callback callbacks: a list of callbacks which might include a time history callback
used during keras.fit. used during keras.fit.
Returns:
Dictionary of normalized results.
""" """
stats = {}
if history and history.history: if history and history.history:
train_hist = history.history train_hist = history.history
...@@ -280,7 +277,7 @@ def build_stats(history, callbacks): ...@@ -280,7 +277,7 @@ def build_stats(history, callbacks):
stats['loss'] = float(train_hist['loss'][-1]) stats['loss'] = float(train_hist['loss'][-1])
if not callbacks: if not callbacks:
return stats return
# Look for the time history callback which was used during keras.fit # Look for the time history callback which was used during keras.fit
for callback in callbacks: for callback in callbacks:
...@@ -293,4 +290,3 @@ def build_stats(history, callbacks): ...@@ -293,4 +290,3 @@ def build_stats(history, callbacks):
callback.batch_size * callback.log_steps * callback.batch_size * callback.log_steps *
(len(callback.timestamp_log)-1) / (len(callback.timestamp_log)-1) /
(timestamp_log[-1].timestamp - timestamp_log[0].timestamp)) (timestamp_log[-1].timestamp - timestamp_log[0].timestamp))
return stats
...@@ -363,7 +363,8 @@ class TransformerTask(object): ...@@ -363,7 +363,8 @@ class TransformerTask(object):
stats = ({ stats = ({
"loss": train_loss "loss": train_loss
} if history is None else misc.build_stats(history, callbacks)) } if history is None else {})
misc.update_stats(history, stats, callbacks)
if uncased_score and cased_score: if uncased_score and cased_score:
stats["bleu_uncased"] = uncased_score stats["bleu_uncased"] = uncased_score
stats["bleu_cased"] = cased_score stats["bleu_cased"] = cased_score
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment