"scripts/ci/ci_install_dependency.sh" did not exist on "539df95d2cb5634b92d01ed83ed7c5c60d299a28"
Commit ccc134f5 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Add checkpoint items loading to the evaluator to be consistent with trainer.

PiperOrigin-RevId: 389101303
parent 6a8107f6
......@@ -54,8 +54,15 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
self._model = model
self._global_step = global_step or orbit.utils.create_global_step()
self._checkpoint_exporter = checkpoint_exporter
if hasattr(self.model, "checkpoint_items"):
checkpoint_items = self.model.checkpoint_items
else:
checkpoint_items = {}
self._checkpoint = tf.train.Checkpoint(
global_step=self.global_step, model=self.model)
model=self.model,
global_step=self.global_step,
**checkpoint_items)
self._validation_losses = None
self._validation_metrics = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment