Commit f69ef1cd authored by Frederick Liu's avatar Frederick Liu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 386531154
parent 7c2ff1af
......@@ -142,14 +142,19 @@ class BestCheckpointExporter:
return self._checkpoint_manager
def maybe_export_checkpoint(self, checkpoint, eval_logs, global_step):
def maybe_export_checkpoint(
self, checkpoint, eval_logs, global_step, write_logs=True) -> bool:
"""Compare eval_logs with past eval_logs and export checkpoint if better."""
logging.info('[BestCheckpointExporter] received eval_logs: %s, at step: %d',
eval_logs, global_step)
if self._best_ckpt_logs is None or self._new_metric_is_better(
self._best_ckpt_logs, eval_logs):
self._best_ckpt_logs = eval_logs
self._export_best_eval_metric(checkpoint, self._best_ckpt_logs,
global_step)
if write_logs:
self.export_best_eval_metric(self._best_ckpt_logs, global_step)
self._get_checkpoint_manager(checkpoint).save()
return True
return False
def _maybe_load_best_eval_metric(self):
if not tf.io.gfile.exists(self.best_ckpt_logs_path):
......@@ -180,7 +185,7 @@ class BestCheckpointExporter:
return True
return False
def _export_best_eval_metric(self, checkpoint, eval_logs, global_step):
def export_best_eval_metric(self, eval_logs, global_step):
"""Export evaluation results of the best checkpoint into a json file."""
eval_logs_ext = copy.copy(eval_logs)
eval_logs_ext['best_ckpt_global_step'] = global_step
......@@ -190,8 +195,6 @@ class BestCheckpointExporter:
with tf.io.gfile.GFile(self.best_ckpt_logs_path, 'w') as writer:
writer.write(json.dumps(eval_logs_ext, indent=4) + '\n')
self._get_checkpoint_manager(checkpoint).save()
@property
def best_ckpt_logs(self):
return self._best_ckpt_logs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment