Unverified commit f2eb1701, authored by saberkun, committed by GitHub
Browse files

Merged commit includes the following changes: (#6992)

252522861  by hongkuny<hongkuny@google.com>:

    Remove export using trained model due to implementation error

--
252156812  by yuefengz<yuefengz@google.com>:

    Fix the callback method name in BERT: replaced on_batch_start with on_batch_begin. Without the fix, it won't work with Keras callbacks.

--
251782065  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Internal change

PiperOrigin-RevId: 252522861
parent f7a44074
......@@ -40,7 +40,7 @@ class BenchmarkTimerCallback(tf.keras.callbacks.Callback):
self.timer_records = []
self.start_time = None
def on_batch_start(self, batch, logs=None):
def on_batch_begin(self, batch, logs=None):
if batch < self.num_batches_to_skip:
return
self.start_time = time.time()
......
......@@ -89,7 +89,7 @@ def run_customized_training_loop(
use_remote_tpu: If true, input pipeline ops are placed in TPU worker host
as an optimization.
custom_callbacks: A list of Keras Callbacks objects to run during
training. More specifically, `on_batch_start()`, `on_batch_end()`,
training. More specifically, `on_batch_begin()`, `on_batch_end()`,
methods are invoked during training.
Returns:
......@@ -203,12 +203,12 @@ def run_customized_training_loop(
metric_result)
return metric_result
def _run_callbacks_on_batch_start(batch):
def _run_callbacks_on_batch_begin(batch):
"""Runs custom callbacks at the start of every step."""
if not custom_callbacks:
return
for callback in custom_callbacks:
callback.on_batch_start(batch)
callback.on_batch_begin(batch)
def _run_callbacks_on_batch_end(batch):
"""Runs custom callbacks at the end of every step."""
......@@ -235,7 +235,7 @@ def run_customized_training_loop(
train_loss = None
while current_step < total_training_steps:
current_step += 1
_run_callbacks_on_batch_start(current_step)
_run_callbacks_on_batch_begin(current_step)
train_loss = train_step(train_iterator).numpy().astype(float)
if train_metric:
......
......@@ -204,7 +204,7 @@ def run_bert(strategy, input_meta_data):
logging.info('Training using customized training loop TF 2.0 with distrubuted'
'strategy.')
use_remote_tpu = (FLAGS.strategy_type == 'tpu' and FLAGS.tpu)
trained_model = run_customized_training(
run_customized_training(
strategy,
bert_config,
input_meta_data,
......@@ -217,10 +217,6 @@ def run_bert(strategy, input_meta_data):
FLAGS.init_checkpoint,
use_remote_tpu=use_remote_tpu)
if FLAGS.model_export_path:
model_saving_utils.export_bert_model(
FLAGS.model_export_path, model=trained_model)
def main(_):
# Users should always run this script under TF 2.x
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment