Commit 89599a23 authored by Le Hou's avatar Le Hou Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 306664983
parent 8931d298
...@@ -389,9 +389,11 @@ def run_customized_training_loop( ...@@ -389,9 +389,11 @@ def run_customized_training_loop(
callback.on_batch_end(batch, logs) callback.on_batch_end(batch, logs)
# Training loop starts here. # Training loop starts here.
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer) checkpoint = tf.train.Checkpoint(
model=model, optimizer=optimizer, global_step=optimizer.iterations)
sub_model_checkpoint = tf.train.Checkpoint( sub_model_checkpoint = tf.train.Checkpoint(
model=sub_model) if sub_model_export_name else None model=sub_model,
global_step=optimizer.iterations) if sub_model_export_name else None
latest_checkpoint_file = tf.train.latest_checkpoint(model_dir) latest_checkpoint_file = tf.train.latest_checkpoint(model_dir)
if latest_checkpoint_file: if latest_checkpoint_file:
......
...@@ -197,7 +197,7 @@ def run_keras_compile_fit(model_dir, ...@@ -197,7 +197,7 @@ def run_keras_compile_fit(model_dir,
with strategy.scope(): with strategy.scope():
training_dataset = train_input_fn() training_dataset = train_input_fn()
evaluation_dataset = eval_input_fn() evaluation_dataset = eval_input_fn() if eval_input_fn else None
bert_model, sub_model = model_fn() bert_model, sub_model = model_fn()
optimizer = bert_model.optimizer optimizer = bert_model.optimizer
...@@ -330,7 +330,8 @@ def run_bert(strategy, ...@@ -330,7 +330,8 @@ def run_bert(strategy,
input_meta_data, input_meta_data,
model_config, model_config,
train_input_fn=None, train_input_fn=None,
eval_input_fn=None): eval_input_fn=None,
init_checkpoint=None):
"""Run BERT training.""" """Run BERT training."""
if FLAGS.mode == 'export_only': if FLAGS.mode == 'export_only':
# As Keras ModelCheckpoint callback used with Keras compile/fit() API # As Keras ModelCheckpoint callback used with Keras compile/fit() API
...@@ -377,7 +378,7 @@ def run_bert(strategy, ...@@ -377,7 +378,7 @@ def run_bert(strategy,
eval_steps, eval_steps,
warmup_steps, warmup_steps,
FLAGS.learning_rate, FLAGS.learning_rate,
FLAGS.init_checkpoint, init_checkpoint or FLAGS.init_checkpoint,
train_input_fn, train_input_fn,
eval_input_fn, eval_input_fn,
run_eagerly=FLAGS.run_eagerly, run_eagerly=FLAGS.run_eagerly,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment