"mmdet3d/vscode:/vscode.git/clone" did not exist on "65e2074b30384c833ff6120c705656bf3c347ce1"
Commit c013311e authored by Frederick Liu's avatar Frederick Liu Committed by A. Unique TensorFlower
Browse files

[core] Use the full model instead of checkpoint_items when constructing...

[core] Use the full model instead of checkpoint_items when constructing Checkpoint in multitask evaluator.

PiperOrigin-RevId: 357962526
parent 581c29eb
...@@ -52,19 +52,9 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator): ...@@ -52,19 +52,9 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
self._model = model self._model = model
self._global_step = global_step or orbit.utils.create_global_step() self._global_step = global_step or orbit.utils.create_global_step()
self._checkpoint_exporter = checkpoint_exporter self._checkpoint_exporter = checkpoint_exporter
# TODO(hongkuny): Define a more robust way to handle the training/eval
# checkpoint loading.
if hasattr(self.model, "checkpoint_items"):
# Each evaluation task can have different models and load a subset of
# components from the training checkpoint. This is assuming the
# checkpoint items are able to load the weights of the evaluation model.
checkpoint_items = self.model.checkpoint_items
else:
# This is assuming the evaluation model is exactly the training model.
checkpoint_items = dict(model=self.model)
self._checkpoint = tf.train.Checkpoint( self._checkpoint = tf.train.Checkpoint(
global_step=self.global_step, global_step=self.global_step,
**checkpoint_items) model=self.model)
self._validation_losses = None self._validation_losses = None
self._validation_metrics = None self._validation_metrics = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment