"git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "64551a69186d28db1f499ba373f1b19c6a7ed894"
Unverified commit dc8c6ce1 authored by Hongkun Yu, committed by GitHub

Merged commit includes the following changes: (#7209)

257883986  by hongkuny<hongkuny@google.com>:

    Adds tf.summary for BERT training

--

PiperOrigin-RevId: 257883986
parent fe748d4a
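
The diff below wires TensorFlow 2.x summary writers into the custom BERT training loop. As a minimal, self-contained sketch of the pattern the patch relies on (the log directory and loss values here are illustrative, not taken from the patch): create a file writer for a log directory, record scalars while that writer is the default, and flush so the events reach disk.

import os
import tensorflow as tf

# Hypothetical log directory; the patch uses <model_dir>/summaries/{train,eval}.
log_dir = os.path.join('/tmp/model_dir', 'summaries', 'train')
writer = tf.summary.create_file_writer(log_dir)

for step, loss_value in enumerate([2.3, 1.8, 1.2], start=1):
  with writer.as_default():
    # Each scalar is tagged with the global step so it can be plotted over time.
    tf.summary.scalar('train_loss', loss_value, step=step)
  writer.flush()  # Forces queued events to disk.

The patch applies this pattern twice: an eval writer under summaries/eval, and a train writer under summaries/train that is only created when steps_per_loop >= _MIN_SUMMARY_STEPS, i.e. when stats are collected over enough steps.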
@@ -24,7 +24,8 @@ import os
 from absl import logging
 import tensorflow as tf
 
-SUMMARY_TXT = 'training_summary.txt'
+_SUMMARY_TXT = 'training_summary.txt'
+_MIN_SUMMARY_STEPS = 10
 
 
 def get_primary_cpu_task(use_remote_tpu=False):
@@ -76,6 +77,14 @@ def _steps_to_run(current_step, steps_per_epoch, steps_per_loop):
   return steps_per_loop
 
 
+def _write_txt_summary(training_summary, model_dir):
+  """Writes a summary text file to record stats."""
+  summary_path = os.path.join(model_dir, _SUMMARY_TXT)
+  with tf.io.gfile.GFile(summary_path, 'wb') as f:
+    logging.info('Training Summary: \n%s', str(training_summary))
+    f.write(json.dumps(training_summary, indent=4))
+
+
 def run_customized_training_loop(
     # pylint: disable=invalid-name
     _sentinel=None,
@@ -167,14 +176,14 @@ def run_customized_training_loop(
     raise ValueError(
         'if `metric_fn` is specified, metric_fn must be a callable.')
 
+  total_training_steps = steps_per_epoch * epochs
+
   # To reduce unnecessary send/receive input pipeline operation, we place input
   # pipeline ops in worker task.
   with tf.device(get_primary_cpu_task(use_remote_tpu)):
     train_iterator = _get_input_iterator(train_input_fn, strategy)
 
     with strategy.scope():
-      total_training_steps = steps_per_epoch * epochs
-
       # To correctly place the model weights on accelerators,
       # model and optimizer should be created in scope.
       model, sub_model = model_fn()
@@ -200,6 +209,17 @@ def run_customized_training_loop(
           eval_metric.__class__.from_config(eval_metric.get_config())
           if eval_metric else None)
 
+      # Create summary writers
+      eval_summary_writer = tf.summary.create_file_writer(
+          os.path.join(model_dir, 'summaries/eval'))
+      if steps_per_loop >= _MIN_SUMMARY_STEPS:
+        # Only writes summary when the stats are collected sufficiently over
+        # enough steps.
+        train_summary_writer = tf.summary.create_file_writer(
+            os.path.join(model_dir, 'summaries/train'))
+      else:
+        train_summary_writer = None
+
       def _replicated_step(inputs):
         """Replicated training step."""
@@ -262,8 +282,13 @@ def run_customized_training_loop(
         """Runs validation steps and aggregate metrics."""
        for _ in range(eval_steps):
           test_step(test_iterator)
+        eval_metric_value = _float_metric_value(eval_metric)
         logging.info('Step: [%d] Validation metric = %f', current_training_step,
-                     _float_metric_value(eval_metric))
+                     eval_metric_value)
+        with eval_summary_writer.as_default():
+          tf.summary.scalar(
+              eval_metric.name, eval_metric_value, step=current_training_step)
+          eval_summary_writer.flush()
 
     def _run_callbacks_on_batch_begin(batch):
       """Runs custom callbacks at the start of every step."""
@@ -314,15 +339,26 @@ def run_customized_training_loop(
         _run_callbacks_on_batch_end(current_step)
         current_step += steps
 
+        train_loss = _float_metric_value(train_loss_metric)
         # Updates training logging.
         training_status = 'Train Step: %d/%d / loss = %s' % (
-            current_step, total_training_steps,
-            _float_metric_value(train_loss_metric))
+            current_step, total_training_steps, train_loss)
         if train_metric:
-          training_status += ' training metric = %s' % _float_metric_value(
-              train_metric)
+          train_metric_value = _float_metric_value(train_metric)
+          training_status += ' training metric = %f' % train_metric_value
+        else:
+          train_metric_value = None
         logging.info(training_status)
+
+        if train_summary_writer:
+          with train_summary_writer.as_default():
+            tf.summary.scalar(
+                train_loss_metric.name, train_loss, step=current_step)
+            if train_metric_value:
+              tf.summary.scalar(
+                  train_metric.name, train_metric_value, step=current_step)
+            train_summary_writer.flush()
 
         # Saves model checkpoints and run validation steps at every epoch end.
         if current_step % steps_per_epoch == 0:
          # To avoid repeated model saving, we do not save after the last
@@ -355,9 +391,6 @@ def run_customized_training_loop(
            train_metric)
         training_summary['eval_metrics'] = _float_metric_value(eval_metric)
 
-      summary_path = os.path.join(model_dir, SUMMARY_TXT)
-      with tf.io.gfile.GFile(summary_path, 'wb') as f:
-        logging.info('Training Summary: \n%s', str(training_summary))
-        f.write(json.dumps(training_summary, indent=4))
+      _write_txt_summary(training_summary, model_dir)
 
       return model
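
For reference, a standalone version of the text-summary helper the patch factors out, with the imports the surrounding module is assumed to already have and a hypothetical example call (the stats dict and model directory are illustrative):

import json
import os

from absl import logging
import tensorflow as tf

_SUMMARY_TXT = 'training_summary.txt'


def _write_txt_summary(training_summary, model_dir):
  """Writes a summary text file to record stats."""
  summary_path = os.path.join(model_dir, _SUMMARY_TXT)
  with tf.io.gfile.GFile(summary_path, 'wb') as f:
    logging.info('Training Summary: \n%s', str(training_summary))
    f.write(json.dumps(training_summary, indent=4))


# Illustrative call, not from the patch.
_write_txt_summary({'total_training_steps': 1000, 'train_loss': 1.2},
                   '/tmp/model_dir')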