"...git@developer.sourcefind.cn:OpenDAS/torch-scatter.git" did not exist on "411e3e387dff931159af2b3caae53354456097cf"
Commit adb61343 authored by Chen Chen, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 311069693
parent 49b223b6
@@ -137,6 +137,10 @@ class BertPretrainAccuracyBenchmark(bert_benchmark_utils.BertBenchmarkBase):
         'benchmark_accuracy_8x8_tpu_bf16_seq128_1m_steps')
     summary_path = os.path.join(FLAGS.model_dir,
                                 'summaries/training_summary.txt')
+    # Set train_summary_interval to -1 to disable training summaries, because
+    # writing summaries to GCS may fail and they are not needed for this
+    # accuracy benchmark test.
+    FLAGS.train_summary_interval = -1
     self._run_and_report_benchmark(summary_path=summary_path,
                                    report_accuracy=True)
@@ -89,6 +89,8 @@ def steps_to_run(current_step, steps_per_epoch, steps_per_loop):
 def write_txt_summary(training_summary, summary_dir):
   """Writes a summary text file to record stats."""
+  if not tf.io.gfile.exists(summary_dir):
+    tf.io.gfile.mkdir(summary_dir)
   summary_path = os.path.join(summary_dir, _SUMMARY_TXT)
   with tf.io.gfile.GFile(summary_path, 'wb') as f:
     logging.info('Training Summary: \n%s', str(training_summary))
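The guard added above creates the summary directory before the file is opened, going through tf.io.gfile so the same code path works on local disk and on remote filesystems such as GCS. A minimal standalone sketch of the pattern; the directory path is a placeholder, not taken from the commit:

    import os
    import tensorflow as tf

    summary_dir = '/tmp/bert_summaries'  # placeholder path for illustration
    if not tf.io.gfile.exists(summary_dir):
      tf.io.gfile.mkdir(summary_dir)  # creates a single level, like os.mkdir

    with tf.io.gfile.GFile(os.path.join(summary_dir, 'summary.txt'), 'w') as f:
      f.write('{"train_loss": 1.23}')

Note that tf.io.gfile.mkdir creates only one directory level; tf.io.gfile.makedirs is the variant to reach for when parent directories may be missing too.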
@@ -117,7 +119,8 @@ def run_customized_training_loop(
     sub_model_export_name=None,
     explicit_allreduce=False,
     pre_allreduce_callbacks=None,
-    post_allreduce_callbacks=None):
+    post_allreduce_callbacks=None,
+    train_summary_interval=0):
   """Run BERT pretrain model training using low-level API.
   Arguments:
@@ -181,6 +184,8 @@ def run_customized_training_loop(
         functions will be invoked in the list order and right before gradients
         are applied to variables for updates. Default is no callbacks. Only used
         when explicit_allreduce=True.
+      train_summary_interval: Step interval for training summaries. If the value
+        is a negative number, then training summaries are not enabled.
   Returns:
       Trained model.
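The new argument defines three regimes: a negative value disables training summaries, zero writes at every opportunity, and a positive value N writes at most once every N steps. A self-contained sketch of the gating rule the loop applies further below; the helper name is illustrative, not part of the library:

    def should_write_summary(current_step, last_summary_step, interval):
      """Net effect of train_summary_interval on a given step."""
      if interval < 0:
        return False  # negative: training summaries are disabled entirely
      # interval == 0 passes every time; interval == N passes at most once
      # per N steps, provided last_summary_step is updated on each write.
      return current_step >= last_summary_step + interval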
@@ -272,13 +277,14 @@ def run_customized_training_loop(
     summary_dir = tempfile.mkdtemp()
   eval_summary_writer = tf.summary.create_file_writer(
       os.path.join(summary_dir, 'eval'))
-  if steps_per_loop >= _MIN_SUMMARY_STEPS:
+  last_summary_step = 0
+  if steps_per_loop >= _MIN_SUMMARY_STEPS and train_summary_interval >= 0:
     # Only writes summary when the stats are collected sufficiently over
     # enough steps.
     train_summary_writer = tf.summary.create_file_writer(
         os.path.join(summary_dir, 'train'))
   else:
-    train_summary_writer = None
+    train_summary_writer = tf.summary.create_noop_writer()

   # Collects training variables.
   training_vars = model.trainable_variables
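Replacing None with tf.summary.create_noop_writer() is a null-object change: downstream code can enter the writer's context and call flush() unconditionally, and the noop writer silently discards every event. A minimal sketch, with an illustrative toggle standing in for the real conditions:

    import tensorflow as tf

    want_summaries = False  # illustrative toggle
    writer = (tf.summary.create_file_writer('/tmp/train')  # placeholder dir
              if want_summaries else tf.summary.create_noop_writer())

    with writer.as_default():                 # same call site either way
      tf.summary.scalar('loss', 0.5, step=1)  # dropped by the noop writer
    writer.flush()                            # also a no-op

This is what lets the training loop below drop its `if train_summary_writer:` branch.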
@@ -438,15 +444,20 @@ def run_customized_training_loop(
       training_status = 'Train Step: %d/%d / loss = %s' % (
           current_step, total_training_steps, train_loss)
-      if train_summary_writer:
-        with train_summary_writer.as_default():
-          tf.summary.scalar(
-              train_loss_metric.name, train_loss, step=current_step)
-          for metric in train_metrics + model.metrics:
-            metric_value = _float_metric_value(metric)
-            training_status += ' %s = %f' % (metric.name, metric_value)
-            tf.summary.scalar(metric.name, metric_value, step=current_step)
-          train_summary_writer.flush()
+      if current_step >= last_summary_step + train_summary_interval:
+        summary_writer = train_summary_writer
+        last_summary_step = current_step
+      else:
+        summary_writer = tf.summary.create_noop_writer()
+      with summary_writer.as_default():
+        tf.summary.scalar(
+            train_loss_metric.name, train_loss, step=current_step)
+        for metric in train_metrics + model.metrics:
+          metric_value = _float_metric_value(metric)
+          training_status += ' %s = %f' % (metric.name, metric_value)
+          tf.summary.scalar(metric.name, metric_value, step=current_step)
+        summary_writer.flush()
       logging.info(training_status)
       if current_step % steps_per_epoch == 0:
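Because current_step advances by steps_per_loop on each pass, a summary lands on the first loop boundary at or after each interval rather than exactly every train_summary_interval steps. A worked example under assumed values, neither taken from the commit:

    steps_per_loop = 300           # assumed for illustration
    train_summary_interval = 1000  # assumed for illustration
    last_summary_step = 0
    for current_step in range(steps_per_loop, 3001, steps_per_loop):
      if current_step >= last_summary_step + train_summary_interval:
        last_summary_step = current_step
        print('summary at step', current_step)
    # prints: summary at step 1200, then summary at step 2400

Note also that the body under summary_writer.as_default() runs on every step regardless, so training_status still accumulates metric values for the log line even when a noop writer discards the scalars.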
@@ -49,6 +49,9 @@ flags.DEFINE_float('warmup_steps', 10000,
                    'Warmup steps for Adam weight decay optimizer.')
 flags.DEFINE_bool('use_next_sentence_label', True,
                   'Whether to use next sentence label to compute final loss.')
+flags.DEFINE_integer('train_summary_interval', 0, 'Step interval for training '
+                     'summaries. If the value is a negative number, '
+                     'then training summaries are not enabled.')

 common_flags.define_common_bert_flags()
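The flag is an ordinary absl integer flag (an integer definition, not a bool, since it holds a step count), so the interval can be set from the command line. A self-contained sketch of the parsing behavior; the script and main function are illustrative, not part of the repository:

    from absl import app, flags

    flags.DEFINE_integer('train_summary_interval', 0,
                         'Step interval for training summaries; a negative '
                         'value disables them.')
    FLAGS = flags.FLAGS

    def main(_):
      print('train_summary_interval =', FLAGS.train_summary_interval)

    if __name__ == '__main__':
      app.run(main)  # e.g. python demo.py --train_summary_interval=1000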
@@ -101,6 +104,7 @@ def run_customized_training(strategy,
                             input_files,
                             train_batch_size,
                             use_next_sentence_label=True,
+                            train_summary_interval=0,
                             custom_callbacks=None):
   """Run BERT pretrain model training using low-level API."""
@@ -135,6 +139,7 @@ def run_customized_training(strategy,
       steps_per_loop=steps_per_loop,
       epochs=epochs,
       sub_model_export_name='pretrained/bert_model',
+      train_summary_interval=train_summary_interval,
       custom_callbacks=custom_callbacks)

   return trained_model
@@ -170,6 +175,7 @@ def run_bert_pretrain(strategy, custom_callbacks=None):
       FLAGS.input_files,
       FLAGS.train_batch_size,
       FLAGS.use_next_sentence_label,
+      FLAGS.train_summary_interval,
       custom_callbacks=custom_callbacks)