"tests/vscode:/vscode.git/clone" did not exist on "9c7f7fc475eb5aa171adcdcae9e7b6dc1bd7034f"
Commit 89599a23 authored by Le Hou, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 306664983
parent 8931d298
@@ -389,9 +389,11 @@ def run_customized_training_loop(
       callback.on_batch_end(batch, logs)
   # Training loop starts here.
-  checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
+  checkpoint = tf.train.Checkpoint(
+      model=model, optimizer=optimizer, global_step=optimizer.iterations)
   sub_model_checkpoint = tf.train.Checkpoint(
-      model=sub_model) if sub_model_export_name else None
+      model=sub_model,
+      global_step=optimizer.iterations) if sub_model_export_name else None
   latest_checkpoint_file = tf.train.latest_checkpoint(model_dir)
   if latest_checkpoint_file:
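The hunk above adds the optimizer's step counter to both checkpoint objects, so the global step is saved and restored alongside the model and optimizer weights. A minimal sketch of that pattern, outside this commit's code (the model, directory, and step count are illustrative):

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adam()

# Tracking optimizer.iterations under the name 'global_step' writes the step
# counter into the checkpoint, so a resumed job knows where it left off.
checkpoint = tf.train.Checkpoint(
    model=model, optimizer=optimizer, global_step=optimizer.iterations)
manager = tf.train.CheckpointManager(checkpoint, '/tmp/ckpt_demo', max_to_keep=1)

optimizer.iterations.assign_add(100)  # stand-in for running 100 training steps
manager.save()

# After a restart, restoring the checkpoint also restores the step counter.
checkpoint.restore(manager.latest_checkpoint)
print('resumed at step', int(checkpoint.global_step))  # -> 100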
@@ -197,7 +197,7 @@ def run_keras_compile_fit(model_dir,
   with strategy.scope():
     training_dataset = train_input_fn()
-    evaluation_dataset = eval_input_fn()
+    evaluation_dataset = eval_input_fn() if eval_input_fn else None
     bert_model, sub_model = model_fn()
     optimizer = bert_model.optimizer
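This hunk makes the evaluation dataset optional so callers can train without supplying an eval_input_fn. A short sketch of the same guard in the spirit of Keras compile/fit (the function and argument names here are illustrative, not the repo's API):

def compile_and_fit(model, train_input_fn, eval_input_fn=None, epochs=1):
  training_dataset = train_input_fn()
  # Only materialize the eval dataset when an input_fn was supplied;
  # passing validation_data=None makes fit() skip the validation pass.
  evaluation_dataset = eval_input_fn() if eval_input_fn else None
  model.fit(training_dataset, validation_data=evaluation_dataset, epochs=epochs)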
@@ -330,7 +330,8 @@ def run_bert(strategy,
              input_meta_data,
              model_config,
              train_input_fn=None,
-             eval_input_fn=None):
+             eval_input_fn=None,
+             init_checkpoint=None):
   """Run BERT training."""
   if FLAGS.mode == 'export_only':
     # As Keras ModelCheckpoint callback used with Keras compile/fit() API
@@ -377,7 +378,7 @@ def run_bert(strategy,
       eval_steps,
       warmup_steps,
       FLAGS.learning_rate,
-      FLAGS.init_checkpoint,
+      init_checkpoint or FLAGS.init_checkpoint,
       train_input_fn,
       eval_input_fn,
       run_eagerly=FLAGS.run_eagerly,
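The last two hunks let callers pass an initialization checkpoint directly to run_bert, keeping the --init_checkpoint flag as the fallback. A small, self-contained sketch of that precedence (the helper name and paths are illustrative):

def resolve_init_checkpoint(init_checkpoint, flag_value):
  # A non-empty argument wins; otherwise fall back to the flag, which may be None.
  return init_checkpoint or flag_value

assert resolve_init_checkpoint('gs://bucket/bert_model.ckpt', 'flag.ckpt') == (
    'gs://bucket/bert_model.ckpt')
assert resolve_init_checkpoint(None, 'flag.ckpt') == 'flag.ckpt'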