"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "a576ea6086fd471f812aba6dfac70ed8aa321deb"
Commit 533d1e6b authored by Will Cromar, committed by A. Unique TensorFlower

Add TimeHistory callback to BERT.

PiperOrigin-RevId: 298466825
parent 7152763a
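For context, TimeHistory lives in official/utils/misc/keras_utils.py and is an ordinary tf.keras callback, so besides the custom training loop wired up below it can be handed straight to model.fit(). A minimal sketch, assuming only the constructor arguments visible in this diff; the model and dataset here are hypothetical:

import tensorflow as tf
from official.utils.misc import keras_utils

# Hypothetical toy model/dataset; only the callback wiring mirrors this commit.
model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
model.compile(optimizer='adam', loss='mse')
dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal([64, 4]), tf.random.normal([64, 2]))).batch(8)

time_callback = keras_utils.TimeHistory(
    batch_size=8,         # global batch size, used to compute examples/second
    log_steps=10,         # log timing information every 10 steps
    logdir='/tmp/model')  # also writes TensorBoard summaries when set
model.fit(dataset, epochs=1, callbacks=[time_callback])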
--- a/official/nlp/bert/model_training_utils.py
+++ b/official/nlp/bert/model_training_utils.py
@@ -368,8 +368,8 @@ def run_customized_training_loop(
       train_steps(train_iterator,
                   tf.convert_to_tensor(steps, dtype=tf.int32))
       train_loss = _float_metric_value(train_loss_metric)
-      _run_callbacks_on_batch_end(current_step, {'loss': train_loss})
       current_step += steps
+      _run_callbacks_on_batch_end(current_step - 1, {'loss': train_loss})

       # Updates training logging.
       training_status = 'Train Step: %d/%d / loss = %s' % (
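The reordering in this hunk matters because train_steps executes several steps per iteration: the batch-end callback now fires after current_step has been advanced and reports current_step - 1, the index of the last step that actually ran, rather than the pre-increment value. A toy illustration of the accounting (the values are made up):

current_step, total_steps, inner_steps = 0, 6, 3
while current_step < total_steps:
    steps = inner_steps                   # steps executed by train_steps(...)
    current_step += steps
    # The callback sees the last completed step index: 2, then 5.
    print('batch end callback at step', current_step - 1)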
--- a/official/nlp/bert/common_flags.py
+++ b/official/nlp/bert/common_flags.py
@@ -69,6 +69,8 @@ def define_common_bert_flags():
   flags.DEFINE_bool('hub_module_trainable', True,
                     'True to make keras layers in the hub module trainable.')

+  flags_core.define_log_steps()
+
   # Adds flags for mixed precision and multi-worker training.
   flags_core.define_performance(
       num_parallel_calls=False,
--- a/official/nlp/bert/run_classifier.py
+++ b/official/nlp/bert/run_classifier.py
@@ -169,7 +169,7 @@ def run_bert_classifier(strategy,
           epochs,
           steps_per_epoch,
           eval_steps,
-          custom_callbacks=None)
+          custom_callbacks=custom_callbacks)

   # Use user-defined loop to start training.
   logging.info('Training using customized training loop TF 2.0 with '
@@ -311,6 +311,15 @@ def run_bert(strategy,
   if not strategy:
     raise ValueError('Distribution strategy has not been specified.')

+  if FLAGS.log_steps:
+    custom_callbacks = [keras_utils.TimeHistory(
+        batch_size=FLAGS.train_batch_size,
+        log_steps=FLAGS.log_steps,
+        logdir=FLAGS.model_dir,
+    )]
+  else:
+    custom_callbacks = None
+
   trained_model = run_bert_classifier(
       strategy,
       model_config,
@@ -326,7 +335,8 @@ def run_bert(strategy,
       train_input_fn,
       eval_input_fn,
       run_eagerly=FLAGS.run_eagerly,
-      use_keras_compile_fit=FLAGS.use_keras_compile_fit)
+      use_keras_compile_fit=FLAGS.use_keras_compile_fit,
+      custom_callbacks=custom_callbacks)

   if FLAGS.model_export_path:
     # As Keras ModelCheckpoint callback used with Keras compile/fit() API
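run_squad.py below repeats the same gating, so the pattern can be read on its own: attach TimeHistory only when --log_steps is set. A standalone sketch using absl (flag names mirror the diff; this is not the repo's actual entry point):

from absl import app, flags
from official.utils.misc import keras_utils

flags.DEFINE_integer('log_steps', 100, 'Steps between timing log entries.')
flags.DEFINE_integer('train_batch_size', 32, 'Global training batch size.')
flags.DEFINE_string('model_dir', '/tmp/model', 'Where summaries are written.')
FLAGS = flags.FLAGS


def main(_):
  # Mirrors run_bert: build the callback list only when log_steps is truthy.
  if FLAGS.log_steps:
    custom_callbacks = [keras_utils.TimeHistory(
        batch_size=FLAGS.train_batch_size,
        log_steps=FLAGS.log_steps,
        logdir=FLAGS.model_dir,
    )]
  else:
    custom_callbacks = None
  print('custom_callbacks:', custom_callbacks)


if __name__ == '__main__':
  app.run(main)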
--- a/official/nlp/bert/run_squad.py
+++ b/official/nlp/bert/run_squad.py
@@ -29,6 +29,7 @@ from official.nlp.bert import run_squad_helper
 from official.nlp.bert import tokenization
 from official.nlp.data import squad_lib as squad_lib_wp
 from official.utils.misc import distribution_utils
+from official.utils.misc import keras_utils

 flags.DEFINE_string('vocab_file', None,
@@ -94,7 +95,21 @@ def main(_):
       all_reduce_alg=FLAGS.all_reduce_alg,
       tpu_address=FLAGS.tpu)
   if FLAGS.mode in ('train', 'train_and_predict'):
-    train_squad(strategy, input_meta_data, run_eagerly=FLAGS.run_eagerly)
+    if FLAGS.log_steps:
+      custom_callbacks = [keras_utils.TimeHistory(
+          batch_size=FLAGS.train_batch_size,
+          log_steps=FLAGS.log_steps,
+          logdir=FLAGS.model_dir,
+      )]
+    else:
+      custom_callbacks = None
+    train_squad(
+        strategy,
+        input_meta_data,
+        custom_callbacks=custom_callbacks,
+        run_eagerly=FLAGS.run_eagerly,
+    )
   if FLAGS.mode in ('predict', 'train_and_predict'):
     predict_squad(strategy, input_meta_data)
--- a/official/utils/flags/_benchmark.py
+++ b/official/utils/flags/_benchmark.py
@@ -23,6 +23,14 @@ from absl import flags
 from official.utils.flags._conventions import help_wrap


+def define_log_steps():
+  flags.DEFINE_integer(
+      name="log_steps", default=100,
+      help="Frequency with which to log timing information with TimeHistory.")
+
+  return []
+
+
 def define_benchmark(benchmark_log_dir=True, bigquery_uploader=True):
   """Register benchmarking flags.
@@ -52,11 +60,7 @@ def define_benchmark(benchmark_log_dir=True, bigquery_uploader=True):
           "human consumption, and does not have any impact within "
           "the system.))

-  flags.DEFINE_integer(
-      name='log_steps', default=100,
-      help='For every log_steps, we log the timing information such as '
-      'examples per second. Besides, for every log_steps, we store the '
-      'timestamp of a batch end.')
+  define_log_steps()

   if benchmark_log_dir:
     flags.DEFINE_string(
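This hunk factors the flag definition out so callers such as common_flags.py can register --log_steps without pulling in every benchmark flag; the empty list it returns follows the convention of the other define_* helpers, which return the key flags they registered. A small sketch of consuming the helper through flags_core (the main function here is hypothetical):

from absl import app, flags
from official.utils.flags import core as flags_core


def main(_):
  # --log_steps defaults to 100, as defined in _benchmark.define_log_steps().
  print('log_steps =', flags.FLAGS.log_steps)


if __name__ == '__main__':
  flags_core.define_log_steps()
  app.run(main)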
--- a/official/utils/flags/core.py
+++ b/official/utils/flags/core.py
@@ -72,6 +72,7 @@ define_base = register_key_flags_in_core(_base.define_base)
 # We have define_base_eager for compatibility, since it used to be a separate
 # function from define_base.
 define_base_eager = define_base
+define_log_steps = register_key_flags_in_core(_benchmark.define_log_steps)
 define_benchmark = register_key_flags_in_core(_benchmark.define_benchmark)
 define_device = register_key_flags_in_core(_device.define_device)
 define_image = register_key_flags_in_core(_misc.define_image)
--- a/official/utils/misc/keras_utils.py
+++ b/official/utils/misc/keras_utils.py
@@ -23,8 +23,7 @@ import os
 import time

 from absl import logging
-import tensorflow as tf
-from tensorflow.core.protobuf import rewriter_config_pb2
+import tensorflow.compat.v2 as tf
 from tensorflow.python import tf2
 from tensorflow.python.profiler import profiler_v2 as profiler
@@ -118,7 +117,7 @@ class TimeHistory(tf.keras.callbacks.Callback):
       self.timestamp_log.append(BatchTimestamp(self.global_steps, now))
       logging.info(
-          "TimeHistory: %.2f examples/second between steps %d and %d",
+          'TimeHistory: %.2f examples/second between steps %d and %d',
           examples_per_second, self.last_log_step, self.global_steps)

       if self.summary_writer:
@@ -209,8 +208,8 @@ def set_session_config(enable_eager=False,
   if enable_eager:
     tf.compat.v1.enable_eager_execution(config=config)
   else:
-    sess = tf.Session(config=config)
-    tf.keras.backend.set_session(sess)
+    sess = tf.compat.v1.Session(config=config)
+    tf.compat.v1.keras.backend.set_session(sess)


 def get_config_proto_v1(enable_xla=False):
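For reference, the throughput figure in the log line above is elapsed steps times batch size divided by the wall-clock time between logs. The arithmetic, pulled out of the callback into a self-contained snippet (the variable values are invented):

import time

batch_size = 32
last_log_step, global_steps = 0, 100   # steps at the previous and current log
start_time = time.time()
# ... log_steps training steps would run here ...
elapsed = time.time() - start_time
steps_since_last_log = global_steps - last_log_step
examples_per_second = steps_since_last_log * batch_size / max(elapsed, 1e-9)
print('TimeHistory: %.2f examples/second between steps %d and %d'
      % (examples_per_second, last_log_step, global_steps))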