Unverified commit dc8c6ce1, authored by Hongkun Yu and committed by GitHub

Merged commit includes the following changes: (#7209)

257883986  by hongkuny<hongkuny@google.com>:

    Adds tf.summary for bert training

--

PiperOrigin-RevId: 257883986
parent fe748d4a
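
For context, the diff below applies the standard TF 2.x summary pattern: create one file writer per log directory with tf.summary.create_file_writer, record scalars inside the writer's as_default() context, and flush so events reach disk. A minimal self-contained sketch of that pattern (not part of the commit; the directory, loss values, and metric name here are illustrative only):

    import os

    import tensorflow as tf

    model_dir = '/tmp/example_model'  # hypothetical output directory

    # One writer per subdirectory, mirroring the diff's 'summaries/train'
    # and 'summaries/eval' layout.
    train_summary_writer = tf.summary.create_file_writer(
        os.path.join(model_dir, 'summaries/train'))

    for step in range(1, 11):
      loss = 1.0 / step  # stand-in for a real training loss
      with train_summary_writer.as_default():
        # `step` should be the global step so points line up across runs.
        tf.summary.scalar('training_loss', loss, step=step)
      train_summary_writer.flush()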
@@ -24,7 +24,8 @@ import os
 from absl import logging
 import tensorflow as tf
 
-SUMMARY_TXT = 'training_summary.txt'
+_SUMMARY_TXT = 'training_summary.txt'
+_MIN_SUMMARY_STEPS = 10
 
 
 def get_primary_cpu_task(use_remote_tpu=False):
@@ -76,6 +77,14 @@ def _steps_to_run(current_step, steps_per_epoch, steps_per_loop):
   return steps_per_loop
 
 
+def _write_txt_summary(training_summary, model_dir):
+  """Writes a summary text file to record stats."""
+  summary_path = os.path.join(model_dir, _SUMMARY_TXT)
+  with tf.io.gfile.GFile(summary_path, 'wb') as f:
+    logging.info('Training Summary: \n%s', str(training_summary))
+    f.write(json.dumps(training_summary, indent=4))
+
+
 def run_customized_training_loop(
     # pylint: disable=invalid-name
     _sentinel=None,
@@ -167,14 +176,14 @@ def run_customized_training_loop(
     raise ValueError(
         'if `metric_fn` is specified, metric_fn must be a callable.')
 
-  total_training_steps = steps_per_epoch * epochs
-
   # To reduce unnecessary send/receive input pipeline operation, we place input
   # pipeline ops in worker task.
   with tf.device(get_primary_cpu_task(use_remote_tpu)):
     train_iterator = _get_input_iterator(train_input_fn, strategy)
 
     with strategy.scope():
+      total_training_steps = steps_per_epoch * epochs
+
       # To correctly place the model weights on accelerators,
       # model and optimizer should be created in scope.
       model, sub_model = model_fn()
@@ -200,6 +209,17 @@ def run_customized_training_loop(
           eval_metric.__class__.from_config(eval_metric.get_config())
           if eval_metric else None)
 
+      # Create summary writers
+      eval_summary_writer = tf.summary.create_file_writer(
+          os.path.join(model_dir, 'summaries/eval'))
+      if steps_per_loop >= _MIN_SUMMARY_STEPS:
+        # Only writes summary when the stats are collected sufficiently over
+        # enough steps.
+        train_summary_writer = tf.summary.create_file_writer(
+            os.path.join(model_dir, 'summaries/train'))
+      else:
+        train_summary_writer = None
+
     def _replicated_step(inputs):
       """Replicated training step."""
@@ -262,8 +282,13 @@ def run_customized_training_loop(
       """Runs validation steps and aggregate metrics."""
       for _ in range(eval_steps):
         test_step(test_iterator)
+      eval_metric_value = _float_metric_value(eval_metric)
       logging.info('Step: [%d] Validation metric = %f', current_training_step,
-                   _float_metric_value(eval_metric))
+                   eval_metric_value)
+      with eval_summary_writer.as_default():
+        tf.summary.scalar(
+            eval_metric.name, eval_metric_value, step=current_training_step)
+        eval_summary_writer.flush()
 
     def _run_callbacks_on_batch_begin(batch):
       """Runs custom callbacks at the start of every step."""
@@ -314,15 +339,26 @@ def run_customized_training_loop(
       _run_callbacks_on_batch_end(current_step)
       current_step += steps
 
+      train_loss = _float_metric_value(train_loss_metric)
+      # Updates training logging.
       training_status = 'Train Step: %d/%d / loss = %s' % (
-          current_step, total_training_steps,
-          _float_metric_value(train_loss_metric))
+          current_step, total_training_steps, train_loss)
 
       if train_metric:
-        training_status += ' training metric = %s' % _float_metric_value(
-            train_metric)
+        train_metric_value = _float_metric_value(train_metric)
+        training_status += ' training metric = %f' % train_metric_value
+      else:
+        train_metric_value = None
       logging.info(training_status)
+      if train_summary_writer:
+        with train_summary_writer.as_default():
+          tf.summary.scalar(
+              train_loss_metric.name, train_loss, step=current_step)
+          if train_metric_value:
+            tf.summary.scalar(
+                train_metric.name, train_metric_value, step=current_step)
+          train_summary_writer.flush()
 
       # Saves model checkpoints and run validation steps at every epoch end.
       if current_step % steps_per_epoch == 0:
         # To avoid repeated model saving, we do not save after the last
@@ -355,9 +391,6 @@ def run_customized_training_loop(
           train_metric)
       training_summary['eval_metrics'] = _float_metric_value(eval_metric)
 
-    summary_path = os.path.join(model_dir, SUMMARY_TXT)
-    with tf.io.gfile.GFile(summary_path, 'wb') as f:
-      logging.info('Training Summary: \n%s', str(training_summary))
-      f.write(json.dumps(training_summary, indent=4))
+    _write_txt_summary(training_summary, model_dir)
 
     return model
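
After this change, scalar summaries accumulate under model_dir/summaries/train and model_dir/summaries/eval, alongside the existing training_summary.txt. Assuming a standard TensorBoard installation, they can be inspected with:

    tensorboard --logdir /path/to/model_dir/summaries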