Commit 7d86c317 authored by Hongkun Yu, committed by A. Unique TensorFlower
Browse files

Add TimeHistory callback to BERT.

PiperOrigin-RevId: 299594839
parent 1b45a4a5
......@@ -395,8 +395,8 @@ def run_customized_training_loop(
train_steps(train_iterator,
tf.convert_to_tensor(steps, dtype=tf.int32))
train_loss = _float_metric_value(train_loss_metric)
_run_callbacks_on_batch_end(current_step, {'loss': train_loss})
current_step += steps
_run_callbacks_on_batch_end(current_step - 1, {'loss': train_loss})
# Updates training logging.
training_status = 'Train Step: %d/%d / loss = %s' % (
......
......@@ -77,8 +77,6 @@ def define_common_bert_flags():
flags.DEFINE_bool('hub_module_trainable', True,
'True to make keras layers in the hub module trainable.')
flags_core.define_log_steps()
# Adds flags for mixed precision and multi-worker training.
flags_core.define_performance(
num_parallel_calls=False,
......
......@@ -169,7 +169,7 @@ def run_bert_classifier(strategy,
epochs,
steps_per_epoch,
eval_steps,
custom_callbacks=custom_callbacks)
custom_callbacks=None)
# Use user-defined loop to start training.
logging.info('Training using customized training loop TF 2.0 with '
......@@ -363,15 +363,6 @@ def run_bert(strategy,
if not strategy:
raise ValueError('Distribution strategy has not been specified.')
if FLAGS.log_steps:
custom_callbacks = [keras_utils.TimeHistory(
batch_size=FLAGS.train_batch_size,
log_steps=FLAGS.log_steps,
logdir=FLAGS.model_dir,
)]
else:
custom_callbacks = None
trained_model = run_bert_classifier(
strategy,
model_config,
......@@ -387,8 +378,7 @@ def run_bert(strategy,
train_input_fn,
eval_input_fn,
run_eagerly=FLAGS.run_eagerly,
use_keras_compile_fit=FLAGS.use_keras_compile_fit,
custom_callbacks=custom_callbacks)
use_keras_compile_fit=FLAGS.use_keras_compile_fit)
if FLAGS.model_export_path:
# As Keras ModelCheckpoint callback used with Keras compile/fit() API
......
......@@ -29,7 +29,6 @@ from official.nlp.bert import run_squad_helper
from official.nlp.bert import tokenization
from official.nlp.data import squad_lib as squad_lib_wp
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
flags.DEFINE_string('vocab_file', None,
......@@ -95,21 +94,7 @@ def main(_):
all_reduce_alg=FLAGS.all_reduce_alg,
tpu_address=FLAGS.tpu)
if FLAGS.mode in ('train', 'train_and_predict'):
if FLAGS.log_steps:
custom_callbacks = [keras_utils.TimeHistory(
batch_size=FLAGS.train_batch_size,
log_steps=FLAGS.log_steps,
logdir=FLAGS.model_dir,
)]
else:
custom_callbacks = None
train_squad(
strategy,
input_meta_data,
custom_callbacks=custom_callbacks,
run_eagerly=FLAGS.run_eagerly,
)
train_squad(strategy, input_meta_data, run_eagerly=FLAGS.run_eagerly)
if FLAGS.mode in ('predict', 'train_and_predict'):
predict_squad(strategy, input_meta_data)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment