Commit 7d86c317 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Add TimeHistory callback to BERT.

PiperOrigin-RevId: 299594839
parent 1b45a4a5
...@@ -395,8 +395,8 @@ def run_customized_training_loop( ...@@ -395,8 +395,8 @@ def run_customized_training_loop(
train_steps(train_iterator, train_steps(train_iterator,
tf.convert_to_tensor(steps, dtype=tf.int32)) tf.convert_to_tensor(steps, dtype=tf.int32))
train_loss = _float_metric_value(train_loss_metric) train_loss = _float_metric_value(train_loss_metric)
_run_callbacks_on_batch_end(current_step, {'loss': train_loss})
current_step += steps current_step += steps
_run_callbacks_on_batch_end(current_step - 1, {'loss': train_loss})
# Updates training logging. # Updates training logging.
training_status = 'Train Step: %d/%d / loss = %s' % ( training_status = 'Train Step: %d/%d / loss = %s' % (
......
...@@ -77,8 +77,6 @@ def define_common_bert_flags(): ...@@ -77,8 +77,6 @@ def define_common_bert_flags():
flags.DEFINE_bool('hub_module_trainable', True, flags.DEFINE_bool('hub_module_trainable', True,
'True to make keras layers in the hub module trainable.') 'True to make keras layers in the hub module trainable.')
flags_core.define_log_steps()
# Adds flags for mixed precision and multi-worker training. # Adds flags for mixed precision and multi-worker training.
flags_core.define_performance( flags_core.define_performance(
num_parallel_calls=False, num_parallel_calls=False,
......
...@@ -169,7 +169,7 @@ def run_bert_classifier(strategy, ...@@ -169,7 +169,7 @@ def run_bert_classifier(strategy,
epochs, epochs,
steps_per_epoch, steps_per_epoch,
eval_steps, eval_steps,
custom_callbacks=custom_callbacks) custom_callbacks=None)
# Use user-defined loop to start training. # Use user-defined loop to start training.
logging.info('Training using customized training loop TF 2.0 with ' logging.info('Training using customized training loop TF 2.0 with '
...@@ -363,15 +363,6 @@ def run_bert(strategy, ...@@ -363,15 +363,6 @@ def run_bert(strategy,
if not strategy: if not strategy:
raise ValueError('Distribution strategy has not been specified.') raise ValueError('Distribution strategy has not been specified.')
if FLAGS.log_steps:
custom_callbacks = [keras_utils.TimeHistory(
batch_size=FLAGS.train_batch_size,
log_steps=FLAGS.log_steps,
logdir=FLAGS.model_dir,
)]
else:
custom_callbacks = None
trained_model = run_bert_classifier( trained_model = run_bert_classifier(
strategy, strategy,
model_config, model_config,
...@@ -387,8 +378,7 @@ def run_bert(strategy, ...@@ -387,8 +378,7 @@ def run_bert(strategy,
train_input_fn, train_input_fn,
eval_input_fn, eval_input_fn,
run_eagerly=FLAGS.run_eagerly, run_eagerly=FLAGS.run_eagerly,
use_keras_compile_fit=FLAGS.use_keras_compile_fit, use_keras_compile_fit=FLAGS.use_keras_compile_fit)
custom_callbacks=custom_callbacks)
if FLAGS.model_export_path: if FLAGS.model_export_path:
# As Keras ModelCheckpoint callback used with Keras compile/fit() API # As Keras ModelCheckpoint callback used with Keras compile/fit() API
......
...@@ -29,7 +29,6 @@ from official.nlp.bert import run_squad_helper ...@@ -29,7 +29,6 @@ from official.nlp.bert import run_squad_helper
from official.nlp.bert import tokenization from official.nlp.bert import tokenization
from official.nlp.data import squad_lib as squad_lib_wp from official.nlp.data import squad_lib as squad_lib_wp
from official.utils.misc import distribution_utils from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
flags.DEFINE_string('vocab_file', None, flags.DEFINE_string('vocab_file', None,
...@@ -95,21 +94,7 @@ def main(_): ...@@ -95,21 +94,7 @@ def main(_):
all_reduce_alg=FLAGS.all_reduce_alg, all_reduce_alg=FLAGS.all_reduce_alg,
tpu_address=FLAGS.tpu) tpu_address=FLAGS.tpu)
if FLAGS.mode in ('train', 'train_and_predict'): if FLAGS.mode in ('train', 'train_and_predict'):
if FLAGS.log_steps: train_squad(strategy, input_meta_data, run_eagerly=FLAGS.run_eagerly)
custom_callbacks = [keras_utils.TimeHistory(
batch_size=FLAGS.train_batch_size,
log_steps=FLAGS.log_steps,
logdir=FLAGS.model_dir,
)]
else:
custom_callbacks = None
train_squad(
strategy,
input_meta_data,
custom_callbacks=custom_callbacks,
run_eagerly=FLAGS.run_eagerly,
)
if FLAGS.mode in ('predict', 'train_and_predict'): if FLAGS.mode in ('predict', 'train_and_predict'):
predict_squad(strategy, input_meta_data) predict_squad(strategy, input_meta_data)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment