Commit c013311e authored by Frederick Liu's avatar Frederick Liu Committed by A. Unique TensorFlower

[core] Use the full model instead of checkpoint_items when constructing Checkpoint in multitask evaluator.

PiperOrigin-RevId: 357962526
parent 581c29eb
@@ -52,19 +52,9 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
     self._model = model
     self._global_step = global_step or orbit.utils.create_global_step()
     self._checkpoint_exporter = checkpoint_exporter
-    # TODO(hongkuny): Define a more robust way to handle the training/eval
-    # checkpoint loading.
-    if hasattr(self.model, "checkpoint_items"):
-      # Each evaluation task can have different models and load a subset of
-      # components from the training checkpoint. This is assuming the
-      # checkpoint items are able to load the weights of the evaluation model.
-      checkpoint_items = self.model.checkpoint_items
-    else:
-      # This is assuming the evaluation model is exactly the training model.
-      checkpoint_items = dict(model=self.model)
     self._checkpoint = tf.train.Checkpoint(
         global_step=self.global_step,
-        **checkpoint_items)
+        model=self.model)
     self._validation_losses = None
     self._validation_metrics = None
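To illustrate the behavioral change, here is a minimal sketch of the checkpoint-item selection logic this commit removes, using a plain stand-in object instead of a Keras model so it runs without TensorFlow; `DummyModel` and both helper functions are hypothetical names introduced for this example.

```python
class DummyModel:
    """Stand-in for a model that may expose a `checkpoint_items` mapping."""

    def __init__(self, checkpoint_items=None):
        # Only set the attribute when items are provided, so hasattr()
        # distinguishes the two cases, mirroring the original code path.
        if checkpoint_items is not None:
            self.checkpoint_items = checkpoint_items


def select_checkpoint_items_old(model):
    """Pre-change behavior: prefer the model's own checkpoint_items."""
    if hasattr(model, "checkpoint_items"):
        # Load a subset of components from the training checkpoint.
        return model.checkpoint_items
    # Otherwise assume the evaluation model is exactly the training model.
    return dict(model=model)


def select_checkpoint_items_new(model):
    """Post-change behavior: always checkpoint the full model."""
    return dict(model=model)
```

After the change, the `tf.train.Checkpoint` in the evaluator always tracks the whole model under the `model` key, rather than an arbitrary set of sub-components, which keeps the evaluation checkpoint layout identical to a plain training checkpoint.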