Commit adb61343 authored by Chen Chen, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 311069693
parent 49b223b6
@@ -137,6 +137,10 @@ class BertPretrainAccuracyBenchmark(bert_benchmark_utils.BertBenchmarkBase):
         'benchmark_accuracy_8x8_tpu_bf16_seq128_1m_steps')
     summary_path = os.path.join(FLAGS.model_dir,
                                 'summaries/training_summary.txt')
+    # Set train_summary_interval to -1 to disable training summary, because
+    # writing summary to gcs may fail and summaries are not needed for this
+    # accuracy benchmark test.
+    FLAGS.train_summary_interval = -1
     self._run_and_report_benchmark(summary_path=summary_path,
                                    report_accuracy=True)
...
@@ -89,6 +89,8 @@ def steps_to_run(current_step, steps_per_epoch, steps_per_loop):

 def write_txt_summary(training_summary, summary_dir):
   """Writes a summary text file to record stats."""
+  if not tf.io.gfile.exists(summary_dir):
+    tf.io.gfile.mkdir(summary_dir)
   summary_path = os.path.join(summary_dir, _SUMMARY_TXT)
   with tf.io.gfile.GFile(summary_path, 'wb') as f:
     logging.info('Training Summary: \n%s', str(training_summary))
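For reference, the patched helper reads as below. This is a hedged reconstruction: the _SUMMARY_TXT constant value and the json.dumps write after the shown context are assumed from the surrounding module, not part of this hunk.

import json
import os

from absl import logging
import tensorflow as tf

_SUMMARY_TXT = 'training_summary.txt'  # assumed module-level constant


def write_txt_summary(training_summary, summary_dir):
  """Writes a summary text file to record stats."""
  # New in this change: create the directory first, so the GFile open below
  # cannot fail on a missing path (e.g. on GCS).
  if not tf.io.gfile.exists(summary_dir):
    tf.io.gfile.mkdir(summary_dir)
  summary_path = os.path.join(summary_dir, _SUMMARY_TXT)
  with tf.io.gfile.GFile(summary_path, 'wb') as f:
    logging.info('Training Summary: \n%s', str(training_summary))
    f.write(json.dumps(training_summary, indent=4))  # assumed function body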
@@ -117,7 +119,8 @@ def run_customized_training_loop(
     sub_model_export_name=None,
     explicit_allreduce=False,
     pre_allreduce_callbacks=None,
-    post_allreduce_callbacks=None):
+    post_allreduce_callbacks=None,
+    train_summary_interval=0):
   """Run BERT pretrain model training using low-level API.

   Arguments:
@@ -181,6 +184,8 @@ def run_customized_training_loop(
       functions will be invoked in the list order and right before gradients
       are applied to variables for updates. Default is no callbacks. Only used
       when explicit_allreduce=True.
+    train_summary_interval: Step interval for training summaries. If the
+      value is a negative number, then training summaries are not enabled.

   Returns:
       Trained model.
@@ -272,13 +277,14 @@ def run_customized_training_loop(
     summary_dir = tempfile.mkdtemp()
   eval_summary_writer = tf.summary.create_file_writer(
       os.path.join(summary_dir, 'eval'))
-  if steps_per_loop >= _MIN_SUMMARY_STEPS:
+  last_summary_step = 0
+  if steps_per_loop >= _MIN_SUMMARY_STEPS and train_summary_interval >= 0:
     # Only writes summary when the stats are collected sufficiently over
     # enough steps.
     train_summary_writer = tf.summary.create_file_writer(
         os.path.join(summary_dir, 'train'))
   else:
-    train_summary_writer = None
+    train_summary_writer = tf.summary.create_noop_writer()

   # Collects training variables.
   training_vars = model.trainable_variables
@@ -438,15 +444,20 @@ def run_customized_training_loop(
         training_status = 'Train Step: %d/%d  / loss = %s' % (
             current_step, total_training_steps, train_loss)

-        if train_summary_writer:
-          with train_summary_writer.as_default():
-            tf.summary.scalar(
-                train_loss_metric.name, train_loss, step=current_step)
-            for metric in train_metrics + model.metrics:
-              metric_value = _float_metric_value(metric)
-              training_status += ' %s = %f' % (metric.name, metric_value)
-              tf.summary.scalar(metric.name, metric_value, step=current_step)
-            train_summary_writer.flush()
+        if current_step >= last_summary_step + train_summary_interval:
+          summary_writer = train_summary_writer
+          last_summary_step = current_step
+        else:
+          summary_writer = tf.summary.create_noop_writer()
+
+        with summary_writer.as_default():
+          tf.summary.scalar(
+              train_loss_metric.name, train_loss, step=current_step)
+          for metric in train_metrics + model.metrics:
+            metric_value = _float_metric_value(metric)
+            training_status += ' %s = %f' % (metric.name, metric_value)
+            tf.summary.scalar(metric.name, metric_value, step=current_step)
+          summary_writer.flush()
         logging.info(training_status)

         if current_step % steps_per_epoch == 0:
...
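The throttle in the hunk above is the core of this change: a real writer is used only when at least train_summary_interval steps have elapsed since the last emitted summary, and all other steps go through a no-op writer. A minimal self-contained sketch of the same pattern (the dummy loss, path, and interval value are illustrative, not the repo's code):

import tensorflow as tf

train_summary_writer = tf.summary.create_file_writer('/tmp/train_summaries')
interval = 100  # stands in for train_summary_interval
last_summary_step = 0

for current_step in range(1, 501):
  train_loss = 1.0 / current_step  # dummy metric for the sketch
  # Route to the real writer only every `interval` steps; create_noop_writer
  # returns a writer whose summary ops do nothing, so the scalar is dropped.
  if current_step >= last_summary_step + interval:
    summary_writer = train_summary_writer
    last_summary_step = current_step
  else:
    summary_writer = tf.summary.create_noop_writer()
  with summary_writer.as_default():
    tf.summary.scalar('train_loss', train_loss, step=current_step)
  summary_writer.flush()

With the default interval of 0 the condition always holds, so existing callers keep writing a summary on every loop iteration.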
@@ -49,6 +49,9 @@ flags.DEFINE_float('warmup_steps', 10000,
                    'Warmup steps for Adam weight decay optimizer.')
 flags.DEFINE_bool('use_next_sentence_label', True,
                   'Whether to use next sentence label to compute final loss.')
+flags.DEFINE_integer('train_summary_interval', 0, 'Step interval for training '
+                     'summaries. If the value is a negative number, '
+                     'then training summaries are not enabled.')

 common_flags.define_common_bert_flags()
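Note the flag is defined with DEFINE_integer above: the interval takes values like 0 and -1, which a boolean flag (DEFINE_bool, as the page originally rendered it) could not parse. A small standalone demo of the corrected definition, assuming nothing beyond absl:

from absl import app
from absl import flags

FLAGS = flags.FLAGS
flags.DEFINE_integer('train_summary_interval', 0, 'Step interval for training '
                     'summaries. If the value is a negative number, '
                     'then training summaries are not enabled.')


def main(_):
  # e.g. --train_summary_interval=-1 disables training summaries entirely.
  if FLAGS.train_summary_interval < 0:
    print('training summaries disabled')
  else:
    print('summaries at most every %d steps' % FLAGS.train_summary_interval)


if __name__ == '__main__':
  app.run(main)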
@@ -101,6 +104,7 @@ def run_customized_training(strategy,
                             input_files,
                             train_batch_size,
                             use_next_sentence_label=True,
+                            train_summary_interval=0,
                             custom_callbacks=None):
   """Run BERT pretrain model training using low-level API."""
@@ -135,6 +139,7 @@ def run_customized_training(strategy,
       steps_per_loop=steps_per_loop,
       epochs=epochs,
       sub_model_export_name='pretrained/bert_model',
+      train_summary_interval=train_summary_interval,
       custom_callbacks=custom_callbacks)

   return trained_model
@@ -170,6 +175,7 @@ def run_bert_pretrain(strategy, custom_callbacks=None):
       FLAGS.input_files,
       FLAGS.train_batch_size,
       FLAGS.use_next_sentence_label,
+      FLAGS.train_summary_interval,
       custom_callbacks=custom_callbacks)
...