Commit ccc134f5 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Add checkpoint items loading to the evaluator to be consistent with trainer.

PiperOrigin-RevId: 389101303
parent 6a8107f6
......@@ -54,8 +54,15 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
self._model = model
self._global_step = global_step or orbit.utils.create_global_step()
self._checkpoint_exporter = checkpoint_exporter
if hasattr(self.model, "checkpoint_items"):
checkpoint_items = self.model.checkpoint_items
else:
checkpoint_items = {}
self._checkpoint = tf.train.Checkpoint(
global_step=self.global_step, model=self.model)
model=self.model,
global_step=self.global_step,
**checkpoint_items)
self._validation_losses = None
self._validation_metrics = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment