Commit b39958cf authored by Will Cromar, committed by A. Unique TensorFlower

Add TimeHistory callback to BERT.

PiperOrigin-RevId: 300433601
parent 1792fb76
@@ -44,6 +44,11 @@ class BenchmarkTimerCallback(tf.keras.callbacks.Callback):
     self.batch_start_times[batch] = time.time()
 
   def on_batch_end(self, batch, logs=None):
+    # If there are multiple steps_per_loop, the end batch index will not be the
+    # same as the starting index. Use the last starting index instead.
+    if batch not in self.batch_start_times:
+      batch = max(self.batch_start_times.keys())
+
     self.batch_stop_times[batch] = time.time()
 
   def get_examples_per_sec(self, batch_size, num_batches_to_skip=1):
...
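Reviewer note: a minimal, self-contained sketch of why this fallback is needed, using a hypothetical ToyTimerCallback rather than the repository's class. With steps_per_loop > 1, a custom loop can start a group of steps at one batch index and end it at another, so the stop time must be paired with the last recorded start index:

import time

import tensorflow as tf


class ToyTimerCallback(tf.keras.callbacks.Callback):
  """Hypothetical sketch of the start/stop bookkeeping patched above."""

  def __init__(self):
    super().__init__()
    self.batch_start_times = {}
    self.batch_stop_times = {}

  def on_batch_begin(self, batch, logs=None):
    self.batch_start_times[batch] = time.time()

  def on_batch_end(self, batch, logs=None):
    # e.g. on_batch_begin(100) followed by on_batch_end(199) when a host
    # loop runs 100 steps per callback invocation.
    if batch not in self.batch_start_times:
      batch = max(self.batch_start_times.keys())
    self.batch_stop_times[batch] = time.time()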
@@ -419,8 +419,8 @@ def run_customized_training_loop(
         train_steps(train_iterator,
                     tf.convert_to_tensor(steps, dtype=tf.int32))
       train_loss = _float_metric_value(train_loss_metric)
-      _run_callbacks_on_batch_end(current_step, {'loss': train_loss})
       current_step += steps
+      _run_callbacks_on_batch_end(current_step - 1, {'loss': train_loss})
 
       # Updates training logging.
       training_status = 'Train Step: %d/%d / loss = %s' % (
...
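Reviewer note: the old code reported the pre-increment step index, i.e. the same index on_batch_begin saw. Moving the call after current_step += steps and passing current_step - 1 reports the index of the last step that actually ran, which is exactly the situation the fallback above handles. A worked example with assumed values:

# Illustrative step accounting only; the numbers are assumptions.
current_step = 100  # steps completed before this pass through the loop
steps = 100         # steps executed by one train_steps() call

current_step += steps                   # 200 steps now complete
last_completed_step = current_step - 1  # 199, the final step just run

# on_batch_end is now invoked with 199 rather than 100, marking the end
# of the inner loop instead of its start.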
@@ -77,6 +77,8 @@ def define_common_bert_flags():
   flags.DEFINE_bool('hub_module_trainable', True,
                     'True to make keras layers in the hub module trainable.')
 
+  flags_core.define_log_steps()
+
   # Adds flags for mixed precision and multi-worker training.
   flags_core.define_performance(
       num_parallel_calls=False,
...
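Reviewer note: flags_core.define_log_steps() registers the --log_steps flag consumed by run_classifier.py and run_squad.py below. A hedged sketch of what such a helper looks like; the default value and help text here are assumptions, not the repository's exact definition:

from absl import flags


def define_log_steps():
  # Assumed shape: one integer flag controlling how often TimeHistory
  # logs timing information. Leaving it unset disables the callback.
  flags.DEFINE_integer(
      'log_steps', None,
      'Frequency with which to log timing information with TimeHistory.')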
@@ -169,7 +169,7 @@ def run_bert_classifier(strategy,
           epochs,
           steps_per_epoch,
           eval_steps,
-          custom_callbacks=None)
+          custom_callbacks=custom_callbacks)
 
   # Use user-defined loop to start training.
   logging.info('Training using customized training loop TF 2.0 with '
@@ -363,6 +363,15 @@ def run_bert(strategy,
   if not strategy:
     raise ValueError('Distribution strategy has not been specified.')
 
+  if FLAGS.log_steps:
+    custom_callbacks = [keras_utils.TimeHistory(
+        batch_size=FLAGS.train_batch_size,
+        log_steps=FLAGS.log_steps,
+        logdir=FLAGS.model_dir,
+    )]
+  else:
+    custom_callbacks = None
+
   trained_model = run_bert_classifier(
       strategy,
       model_config,
...
@@ -378,7 +387,8 @@
       train_input_fn,
       eval_input_fn,
       run_eagerly=FLAGS.run_eagerly,
-      use_keras_compile_fit=FLAGS.use_keras_compile_fit)
+      use_keras_compile_fit=FLAGS.use_keras_compile_fit,
+      custom_callbacks=custom_callbacks)
 
   if FLAGS.model_export_path:
     # As Keras ModelCheckpoint callback used with Keras compile/fit() API
...
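Reviewer note: on the consuming side, the flag-gated "list or None" built above can be handed straight to Keras, since fit() accepts either. A minimal sketch of that hand-off; the wrapper function is hypothetical, while model.fit is the real Keras API:

def fit_with_callbacks(model, train_dataset, epochs, steps_per_epoch,
                       custom_callbacks=None):
  # tf.keras.Model.fit accepts callbacks=None or a list of
  # tf.keras.callbacks.Callback instances, so no extra branching is needed.
  model.fit(
      train_dataset,
      epochs=epochs,
      steps_per_epoch=steps_per_epoch,
      callbacks=custom_callbacks)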
@@ -29,6 +29,7 @@ from official.nlp.bert import run_squad_helper
 from official.nlp.bert import tokenization
 from official.nlp.data import squad_lib as squad_lib_wp
 from official.utils.misc import distribution_utils
+from official.utils.misc import keras_utils
 
 flags.DEFINE_string('vocab_file', None,
@@ -94,7 +95,21 @@ def main(_):
       all_reduce_alg=FLAGS.all_reduce_alg,
       tpu_address=FLAGS.tpu)
   if FLAGS.mode in ('train', 'train_and_predict'):
-    train_squad(strategy, input_meta_data, run_eagerly=FLAGS.run_eagerly)
+    if FLAGS.log_steps:
+      custom_callbacks = [keras_utils.TimeHistory(
+          batch_size=FLAGS.train_batch_size,
+          log_steps=FLAGS.log_steps,
+          logdir=FLAGS.model_dir,
+      )]
+    else:
+      custom_callbacks = None
+    train_squad(
+        strategy,
+        input_meta_data,
+        custom_callbacks=custom_callbacks,
+        run_eagerly=FLAGS.run_eagerly,
+    )
   if FLAGS.mode in ('predict', 'train_and_predict'):
     predict_squad(strategy, input_meta_data)
...
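Reviewer note: this block duplicates the one added to run_classifier.py. A shared helper would keep the two call sites in sync; a sketch with a hypothetical name (build_time_history_callbacks is not part of this commit):

from absl import flags

from official.utils.misc import keras_utils

FLAGS = flags.FLAGS


def build_time_history_callbacks():
  # Same pattern as both call sites: only attach TimeHistory when
  # --log_steps is set, so default runs pay no callback overhead.
  if FLAGS.log_steps:
    return [keras_utils.TimeHistory(
        batch_size=FLAGS.train_batch_size,
        log_steps=FLAGS.log_steps,
        logdir=FLAGS.model_dir)]
  return None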
@@ -117,8 +117,9 @@ class TimeHistory(tf.keras.callbacks.Callback):
       self.timestamp_log.append(BatchTimestamp(self.global_steps, now))
       logging.info(
-          'TimeHistory: %.2f examples/second between steps %d and %d',
-          examples_per_second, self.last_log_step, self.global_steps)
+          'TimeHistory: %.2f seconds, %.2f examples/second between steps %d '
+          'and %d', elapsed_time, examples_per_second, self.last_log_step,
+          self.global_steps)
 
       if self.summary_writer:
         with self.summary_writer.as_default():
...
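Reviewer note: the new message logs the wall-clock interval alongside throughput. A hedged sketch of the arithmetic behind the two values in a TimeHistory-style callback; the attribute names follow the snippet above, but the surrounding code is assumed, not copied from keras_utils.py:

import time

import tensorflow as tf


class MiniTimeHistory(tf.keras.callbacks.Callback):
  """Sketch of the throughput computation, not the real TimeHistory."""

  def __init__(self, batch_size, log_steps):
    super().__init__()
    self.batch_size = batch_size
    self.log_steps = log_steps
    self.global_steps = 0
    self.last_log_step = 0
    self.start_time = None

  def on_batch_begin(self, batch, logs=None):
    if self.start_time is None:
      self.start_time = time.time()

  def on_batch_end(self, batch, logs=None):
    self.global_steps += 1
    if self.global_steps - self.last_log_step >= self.log_steps:
      elapsed_time = time.time() - self.start_time
      steps = self.global_steps - self.last_log_step
      # The two quantities in the new log line: seconds and examples/sec.
      examples_per_second = steps * self.batch_size / elapsed_time
      print('TimeHistory: %.2f seconds, %.2f examples/second between steps '
            '%d and %d' % (elapsed_time, examples_per_second,
                           self.last_log_step, self.global_steps))
      self.last_log_step = self.global_steps
      self.start_time = None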