Commit c508968c authored by Chen Chen, committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 357078424
parent b9b0be18
...@@ -56,6 +56,20 @@ class BestCheckpointExporter: ...@@ -56,6 +56,20 @@ class BestCheckpointExporter:
'higher, lower. Got: {}'.format(self._metric_comp)) 'higher, lower. Got: {}'.format(self._metric_comp))
tf.io.gfile.makedirs(os.path.dirname(self.best_ckpt_logs_path)) tf.io.gfile.makedirs(os.path.dirname(self.best_ckpt_logs_path))
self._best_ckpt_logs = self._maybe_load_best_eval_metric() self._best_ckpt_logs = self._maybe_load_best_eval_metric()
self._checkpoint_manager = None
def _get_checkpoint_manager(self, checkpoint):
    """Returns the cached checkpoint manager, building a new one if needed.

    A new `tf.train.CheckpointManager` is created when no manager has been
    cached yet, or when the cached manager's `checkpoint` attribute compares
    unequal to the `checkpoint` argument. The new manager replaces the cached
    one.

    Args:
      checkpoint: The checkpoint object the returned manager must wrap.

    Returns:
      The checkpoint manager stored on `self._checkpoint_manager`.
    """
    manager = self._checkpoint_manager
    needs_new_manager = manager is None or manager.checkpoint != checkpoint
    if needs_new_manager:
        logging.info('Creates a new checkpoint manager.')
        # max_to_keep=1 with a fixed name: the manager is configured to
        # retain a single 'best_ckpt' checkpoint under the export directory.
        manager = tf.train.CheckpointManager(
            checkpoint,
            directory=self._export_dir,
            max_to_keep=1,
            checkpoint_name='best_ckpt')
        self._checkpoint_manager = manager
    return manager
def maybe_export_checkpoint(self, checkpoint, eval_logs, global_step): def maybe_export_checkpoint(self, checkpoint, eval_logs, global_step):
logging.info('[BestCheckpointExporter] received eval_logs: %s, at step: %d', logging.info('[BestCheckpointExporter] received eval_logs: %s, at step: %d',
...@@ -105,10 +119,7 @@ class BestCheckpointExporter: ...@@ -105,10 +119,7 @@ class BestCheckpointExporter:
with tf.io.gfile.GFile(self.best_ckpt_logs_path, 'w') as writer: with tf.io.gfile.GFile(self.best_ckpt_logs_path, 'w') as writer:
writer.write(json.dumps(eval_logs_ext, indent=4) + '\n') writer.write(json.dumps(eval_logs_ext, indent=4) + '\n')
# Saving the best checkpoint might be interrupted if the job got killed. self._get_checkpoint_manager(checkpoint).save()
for file_to_remove in tf.io.gfile.glob(self.best_ckpt_path + '*'):
tf.io.gfile.remove(file_to_remove)
checkpoint.write(self.best_ckpt_path)
@property @property
def best_ckpt_logs(self): def best_ckpt_logs(self):
...@@ -120,7 +131,8 @@ class BestCheckpointExporter: ...@@ -120,7 +131,8 @@ class BestCheckpointExporter:
@property @property
def best_ckpt_path(self): def best_ckpt_path(self):
return os.path.join(self._export_dir, 'best_ckpt') """Returns the best ckpt path or None if there is no ckpt yet."""
return tf.train.latest_checkpoint(self._export_dir)
@gin.configurable @gin.configurable
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment