Commit 690e44ed authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 347505505
parent 1ee1969c
...@@ -283,8 +283,8 @@ def run_customized_training_loop( ...@@ -283,8 +283,8 @@ def run_customized_training_loop(
logging.info( logging.info(
'Checkpoint file %s found and restoring from ' 'Checkpoint file %s found and restoring from '
'initial checkpoint for core model.', init_checkpoint) 'initial checkpoint for core model.', init_checkpoint)
checkpoint = tf.train.Checkpoint(model=sub_model) checkpoint = tf.train.Checkpoint(model=sub_model, encoder=sub_model)
checkpoint.restore(init_checkpoint).assert_existing_objects_matched() checkpoint.read(init_checkpoint).assert_existing_objects_matched()
logging.info('Loading from checkpoint file completed') logging.info('Loading from checkpoint file completed')
train_loss_metric = tf.keras.metrics.Mean('training_loss', dtype=tf.float32) train_loss_metric = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)
......
...@@ -220,8 +220,8 @@ def run_keras_compile_fit(model_dir, ...@@ -220,8 +220,8 @@ def run_keras_compile_fit(model_dir,
optimizer = bert_model.optimizer optimizer = bert_model.optimizer
if init_checkpoint: if init_checkpoint:
checkpoint = tf.train.Checkpoint(model=sub_model) checkpoint = tf.train.Checkpoint(model=sub_model, encoder=sub_model)
checkpoint.restore(init_checkpoint).assert_existing_objects_matched() checkpoint.read(init_checkpoint).assert_existing_objects_matched()
if not isinstance(metric_fn, (list, tuple)): if not isinstance(metric_fn, (list, tuple)):
metric_fn = [metric_fn] metric_fn = [metric_fn]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment