Unverified Commit f2eb1701 authored by saberkun's avatar saberkun Committed by GitHub
Browse files

Merged commit includes the following changes: (#6992)

252522861  by hongkuny<hongkuny@google.com>:

    Remove export using trained model due to implementation error

--
252156812  by yuefengz<yuefengz@google.com>:

    Fix the callback method name in BERT: replaced on_batch_start with on_batch_begin. Without the fix, it won't work with Keras callbacks.

--
251782065  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Internal change

PiperOrigin-RevId: 252522861
parent f7a44074
...@@ -40,7 +40,7 @@ class BenchmarkTimerCallback(tf.keras.callbacks.Callback): ...@@ -40,7 +40,7 @@ class BenchmarkTimerCallback(tf.keras.callbacks.Callback):
self.timer_records = [] self.timer_records = []
self.start_time = None self.start_time = None
def on_batch_start(self, batch, logs=None): def on_batch_begin(self, batch, logs=None):
if batch < self.num_batches_to_skip: if batch < self.num_batches_to_skip:
return return
self.start_time = time.time() self.start_time = time.time()
......
...@@ -89,7 +89,7 @@ def run_customized_training_loop( ...@@ -89,7 +89,7 @@ def run_customized_training_loop(
use_remote_tpu: If true, input pipeline ops are placed in TPU worker host use_remote_tpu: If true, input pipeline ops are placed in TPU worker host
as an optimization. as an optimization.
custom_callbacks: A list of Keras Callbacks objects to run during custom_callbacks: A list of Keras Callbacks objects to run during
training. More specifically, `on_batch_start()`, `on_batch_end()`, training. More specifically, `on_batch_begin()`, `on_batch_end()`,
methods are invoked during training. methods are invoked during training.
Returns: Returns:
...@@ -203,12 +203,12 @@ def run_customized_training_loop( ...@@ -203,12 +203,12 @@ def run_customized_training_loop(
metric_result) metric_result)
return metric_result return metric_result
def _run_callbacks_on_batch_start(batch): def _run_callbacks_on_batch_begin(batch):
"""Runs custom callbacks at the start of every step.""" """Runs custom callbacks at the start of every step."""
if not custom_callbacks: if not custom_callbacks:
return return
for callback in custom_callbacks: for callback in custom_callbacks:
callback.on_batch_start(batch) callback.on_batch_begin(batch)
def _run_callbacks_on_batch_end(batch): def _run_callbacks_on_batch_end(batch):
"""Runs custom callbacks at the end of every step.""" """Runs custom callbacks at the end of every step."""
...@@ -235,7 +235,7 @@ def run_customized_training_loop( ...@@ -235,7 +235,7 @@ def run_customized_training_loop(
train_loss = None train_loss = None
while current_step < total_training_steps: while current_step < total_training_steps:
current_step += 1 current_step += 1
_run_callbacks_on_batch_start(current_step) _run_callbacks_on_batch_begin(current_step)
train_loss = train_step(train_iterator).numpy().astype(float) train_loss = train_step(train_iterator).numpy().astype(float)
if train_metric: if train_metric:
......
...@@ -204,7 +204,7 @@ def run_bert(strategy, input_meta_data): ...@@ -204,7 +204,7 @@ def run_bert(strategy, input_meta_data):
logging.info('Training using customized training loop TF 2.0 with distrubuted' logging.info('Training using customized training loop TF 2.0 with distrubuted'
'strategy.') 'strategy.')
use_remote_tpu = (FLAGS.strategy_type == 'tpu' and FLAGS.tpu) use_remote_tpu = (FLAGS.strategy_type == 'tpu' and FLAGS.tpu)
trained_model = run_customized_training( run_customized_training(
strategy, strategy,
bert_config, bert_config,
input_meta_data, input_meta_data,
...@@ -217,10 +217,6 @@ def run_bert(strategy, input_meta_data): ...@@ -217,10 +217,6 @@ def run_bert(strategy, input_meta_data):
FLAGS.init_checkpoint, FLAGS.init_checkpoint,
use_remote_tpu=use_remote_tpu) use_remote_tpu=use_remote_tpu)
if FLAGS.model_export_path:
model_saving_utils.export_bert_model(
FLAGS.model_export_path, model=trained_model)
def main(_): def main(_):
# Users should always run this script under TF 2.x # Users should always run this script under TF 2.x
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment