Commit f69ef1cd authored by Frederick Liu's avatar Frederick Liu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 386531154
parent 7c2ff1af
...@@ -142,14 +142,19 @@ class BestCheckpointExporter: ...@@ -142,14 +142,19 @@ class BestCheckpointExporter:
return self._checkpoint_manager return self._checkpoint_manager
def maybe_export_checkpoint(self, checkpoint, eval_logs, global_step): def maybe_export_checkpoint(
self, checkpoint, eval_logs, global_step, write_logs=True) -> bool:
"""Compare eval_logs with past eval_logs and export checkpoint if better."""
logging.info('[BestCheckpointExporter] received eval_logs: %s, at step: %d', logging.info('[BestCheckpointExporter] received eval_logs: %s, at step: %d',
eval_logs, global_step) eval_logs, global_step)
if self._best_ckpt_logs is None or self._new_metric_is_better( if self._best_ckpt_logs is None or self._new_metric_is_better(
self._best_ckpt_logs, eval_logs): self._best_ckpt_logs, eval_logs):
self._best_ckpt_logs = eval_logs self._best_ckpt_logs = eval_logs
self._export_best_eval_metric(checkpoint, self._best_ckpt_logs, if write_logs:
global_step) self.export_best_eval_metric(self._best_ckpt_logs, global_step)
self._get_checkpoint_manager(checkpoint).save()
return True
return False
def _maybe_load_best_eval_metric(self): def _maybe_load_best_eval_metric(self):
if not tf.io.gfile.exists(self.best_ckpt_logs_path): if not tf.io.gfile.exists(self.best_ckpt_logs_path):
...@@ -180,7 +185,7 @@ class BestCheckpointExporter: ...@@ -180,7 +185,7 @@ class BestCheckpointExporter:
return True return True
return False return False
def _export_best_eval_metric(self, checkpoint, eval_logs, global_step): def export_best_eval_metric(self, eval_logs, global_step):
"""Export evaluation results of the best checkpoint into a json file.""" """Export evaluation results of the best checkpoint into a json file."""
eval_logs_ext = copy.copy(eval_logs) eval_logs_ext = copy.copy(eval_logs)
eval_logs_ext['best_ckpt_global_step'] = global_step eval_logs_ext['best_ckpt_global_step'] = global_step
...@@ -190,8 +195,6 @@ class BestCheckpointExporter: ...@@ -190,8 +195,6 @@ class BestCheckpointExporter:
with tf.io.gfile.GFile(self.best_ckpt_logs_path, 'w') as writer: with tf.io.gfile.GFile(self.best_ckpt_logs_path, 'w') as writer:
writer.write(json.dumps(eval_logs_ext, indent=4) + '\n') writer.write(json.dumps(eval_logs_ext, indent=4) + '\n')
self._get_checkpoint_manager(checkpoint).save()
@property @property
def best_ckpt_logs(self): def best_ckpt_logs(self):
return self._best_ckpt_logs return self._best_ckpt_logs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment