Commit 690e44ed authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 347505505
parent 1ee1969c
......@@ -283,8 +283,8 @@ def run_customized_training_loop(
logging.info(
'Checkpoint file %s found and restoring from '
'initial checkpoint for core model.', init_checkpoint)
checkpoint = tf.train.Checkpoint(model=sub_model)
checkpoint.restore(init_checkpoint).assert_existing_objects_matched()
checkpoint = tf.train.Checkpoint(model=sub_model, encoder=sub_model)
checkpoint.read(init_checkpoint).assert_existing_objects_matched()
logging.info('Loading from checkpoint file completed')
train_loss_metric = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)
......
......@@ -220,8 +220,8 @@ def run_keras_compile_fit(model_dir,
optimizer = bert_model.optimizer
if init_checkpoint:
checkpoint = tf.train.Checkpoint(model=sub_model)
checkpoint.restore(init_checkpoint).assert_existing_objects_matched()
checkpoint = tf.train.Checkpoint(model=sub_model, encoder=sub_model)
checkpoint.read(init_checkpoint).assert_existing_objects_matched()
if not isinstance(metric_fn, (list, tuple)):
metric_fn = [metric_fn]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment