"vscode:/vscode.git/clone" did not exist on "058d3061f772cb85997059b39e476dca5074c29f"
Commit c508968c authored by Chen Chen, committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 357078424
parent b9b0be18
......@@ -56,6 +56,20 @@ class BestCheckpointExporter:
'higher, lower. Got: {}'.format(self._metric_comp))
tf.io.gfile.makedirs(os.path.dirname(self.best_ckpt_logs_path))
self._best_ckpt_logs = self._maybe_load_best_eval_metric()
self._checkpoint_manager = None
def _get_checkpoint_manager(self, checkpoint):
  """Gets an existing checkpoint manager or creates a new one.

  The manager is cached on the instance and reused for as long as the same
  `checkpoint` is passed in. Passing a different `checkpoint` replaces the
  cached manager with a freshly created one.

  Args:
    checkpoint: The checkpoint object the returned manager should save.

  Returns:
    A `tf.train.CheckpointManager` bound to `checkpoint` that writes into
    `self._export_dir` using the name prefix 'best_ckpt'.
  """
  manager = self._checkpoint_manager
  # Rebuild when nothing is cached yet, or when the cached manager was built
  # for a different checkpoint object than the one supplied now.
  if manager is None or manager.checkpoint != checkpoint:
    logging.info('Creates a new checkpoint manager.')
    manager = tf.train.CheckpointManager(
        checkpoint,
        directory=self._export_dir,
        max_to_keep=1,
        checkpoint_name='best_ckpt')
    self._checkpoint_manager = manager
  return manager
def maybe_export_checkpoint(self, checkpoint, eval_logs, global_step):
logging.info('[BestCheckpointExporter] received eval_logs: %s, at step: %d',
......@@ -105,10 +119,7 @@ class BestCheckpointExporter:
with tf.io.gfile.GFile(self.best_ckpt_logs_path, 'w') as writer:
writer.write(json.dumps(eval_logs_ext, indent=4) + '\n')
# Saving the best checkpoint might be interrupted if the job got killed.
for file_to_remove in tf.io.gfile.glob(self.best_ckpt_path + '*'):
tf.io.gfile.remove(file_to_remove)
checkpoint.write(self.best_ckpt_path)
self._get_checkpoint_manager(checkpoint).save()
@property
def best_ckpt_logs(self):
......@@ -120,7 +131,8 @@ class BestCheckpointExporter:
@property
def best_ckpt_path(self):
  """Returns the best ckpt path or None if there is no ckpt yet."""
  # Ask TensorFlow for the latest checkpoint recorded in the export directory
  # rather than assuming a fixed '<export_dir>/best_ckpt' prefix.
  return tf.train.latest_checkpoint(self._export_dir)
@gin.configurable
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment