Commit 533d1e6b authored by Will Cromar's avatar Will Cromar Committed by A. Unique TensorFlower
Browse files

Add TimeHistory callback to BERT.

PiperOrigin-RevId: 298466825
parent 7152763a
......@@ -368,8 +368,8 @@ def run_customized_training_loop(
train_steps(train_iterator,
tf.convert_to_tensor(steps, dtype=tf.int32))
train_loss = _float_metric_value(train_loss_metric)
_run_callbacks_on_batch_end(current_step, {'loss': train_loss})
current_step += steps
_run_callbacks_on_batch_end(current_step - 1, {'loss': train_loss})
# Updates training logging.
training_status = 'Train Step: %d/%d / loss = %s' % (
......
......@@ -69,6 +69,8 @@ def define_common_bert_flags():
flags.DEFINE_bool('hub_module_trainable', True,
'True to make keras layers in the hub module trainable.')
flags_core.define_log_steps()
# Adds flags for mixed precision and multi-worker training.
flags_core.define_performance(
num_parallel_calls=False,
......
......@@ -169,7 +169,7 @@ def run_bert_classifier(strategy,
epochs,
steps_per_epoch,
eval_steps,
custom_callbacks=None)
custom_callbacks=custom_callbacks)
# Use user-defined loop to start training.
logging.info('Training using customized training loop TF 2.0 with '
......@@ -311,6 +311,15 @@ def run_bert(strategy,
if not strategy:
raise ValueError('Distribution strategy has not been specified.')
if FLAGS.log_steps:
custom_callbacks = [keras_utils.TimeHistory(
batch_size=FLAGS.train_batch_size,
log_steps=FLAGS.log_steps,
logdir=FLAGS.model_dir,
)]
else:
custom_callbacks = None
trained_model = run_bert_classifier(
strategy,
model_config,
......@@ -326,7 +335,8 @@ def run_bert(strategy,
train_input_fn,
eval_input_fn,
run_eagerly=FLAGS.run_eagerly,
use_keras_compile_fit=FLAGS.use_keras_compile_fit)
use_keras_compile_fit=FLAGS.use_keras_compile_fit,
custom_callbacks=custom_callbacks)
if FLAGS.model_export_path:
# As Keras ModelCheckpoint callback used with Keras compile/fit() API
......
......@@ -29,6 +29,7 @@ from official.nlp.bert import run_squad_helper
from official.nlp.bert import tokenization
from official.nlp.data import squad_lib as squad_lib_wp
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
flags.DEFINE_string('vocab_file', None,
......@@ -94,7 +95,21 @@ def main(_):
all_reduce_alg=FLAGS.all_reduce_alg,
tpu_address=FLAGS.tpu)
if FLAGS.mode in ('train', 'train_and_predict'):
train_squad(strategy, input_meta_data, run_eagerly=FLAGS.run_eagerly)
if FLAGS.log_steps:
custom_callbacks = [keras_utils.TimeHistory(
batch_size=FLAGS.train_batch_size,
log_steps=FLAGS.log_steps,
logdir=FLAGS.model_dir,
)]
else:
custom_callbacks = None
train_squad(
strategy,
input_meta_data,
custom_callbacks=custom_callbacks,
run_eagerly=FLAGS.run_eagerly,
)
if FLAGS.mode in ('predict', 'train_and_predict'):
predict_squad(strategy, input_meta_data)
......
......@@ -23,6 +23,14 @@ from absl import flags
from official.utils.flags._conventions import help_wrap
def define_log_steps():
flags.DEFINE_integer(
name="log_steps", default=100,
help="Frequency with which to log timing information with TimeHistory.")
return []
def define_benchmark(benchmark_log_dir=True, bigquery_uploader=True):
"""Register benchmarking flags.
......@@ -52,11 +60,7 @@ def define_benchmark(benchmark_log_dir=True, bigquery_uploader=True):
"human consumption, and does not have any impact within "
"the system."))
flags.DEFINE_integer(
name='log_steps', default=100,
help='For every log_steps, we log the timing information such as '
'examples per second. Besides, for every log_steps, we store the '
'timestamp of a batch end.')
define_log_steps()
if benchmark_log_dir:
flags.DEFINE_string(
......
......@@ -72,6 +72,7 @@ define_base = register_key_flags_in_core(_base.define_base)
# We have define_base_eager for compatibility, since it used to be a separate
# function from define_base.
define_base_eager = define_base
define_log_steps = register_key_flags_in_core(_benchmark.define_log_steps)
define_benchmark = register_key_flags_in_core(_benchmark.define_benchmark)
define_device = register_key_flags_in_core(_device.define_device)
define_image = register_key_flags_in_core(_misc.define_image)
......
......@@ -23,8 +23,7 @@ import os
import time
from absl import logging
import tensorflow as tf
from tensorflow.core.protobuf import rewriter_config_pb2
import tensorflow.compat.v2 as tf
from tensorflow.python import tf2
from tensorflow.python.profiler import profiler_v2 as profiler
......@@ -118,7 +117,7 @@ class TimeHistory(tf.keras.callbacks.Callback):
self.timestamp_log.append(BatchTimestamp(self.global_steps, now))
logging.info(
"TimeHistory: %.2f examples/second between steps %d and %d",
'TimeHistory: %.2f examples/second between steps %d and %d',
examples_per_second, self.last_log_step, self.global_steps)
if self.summary_writer:
......@@ -209,8 +208,8 @@ def set_session_config(enable_eager=False,
if enable_eager:
tf.compat.v1.enable_eager_execution(config=config)
else:
sess = tf.Session(config=config)
tf.keras.backend.set_session(sess)
sess = tf.compat.v1.Session(config=config)
tf.compat.v1.keras.backend.set_session(sess)
def get_config_proto_v1(enable_xla=False):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment