"tutorials/models/vscode:/vscode.git/clone" did not exist on "103444c58f66bd7a25407cb8fba6248d8cc976e5"
Commit b8014d55 authored by Simon Kornblith's avatar Simon Kornblith Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 323493207
parent e9b70e67
...@@ -16,8 +16,9 @@ ...@@ -16,8 +16,9 @@
"""A light weight utilities to train TF2 models.""" """A light weight utilities to train TF2 models."""
import time import time
from typing import Callable, Optional, Text, Union from typing import Callable, Dict, Optional, Text, Union
from absl import logging from absl import logging
import numpy as np
from orbit import runner from orbit import runner
from orbit import utils from orbit import utils
...@@ -177,7 +178,7 @@ class Controller: ...@@ -177,7 +178,7 @@ class Controller:
if checkpoint_at_completion: if checkpoint_at_completion:
self.save_checkpoint() self.save_checkpoint()
def evaluate(self, steps: int = None): def evaluate(self, steps: int = None) -> Optional[Dict[Text, np.number]]:
"""Runs evaluation. """Runs evaluation.
This method calls the `evaluate` method on the Evaluator object for `steps` This method calls the `evaluate` method on the Evaluator object for `steps`
...@@ -186,10 +187,12 @@ class Controller: ...@@ -186,10 +187,12 @@ class Controller:
Args: Args:
steps: The number of steps to evaluate for. steps: The number of steps to evaluate for.
Returns:
The evaluation results as a dictionary of numpy values.
Raises: Raises:
ValueError: If no checkpoint found in `self.checkpoint_manager.directory`. ValueError: If no checkpoint found in `self.checkpoint_manager.directory`.
ValueError: If `evaluator` is not provided. ValueError: If `evaluator` is not provided.
""" """
if self.evaluator is None: if self.evaluator is None:
raise ValueError("`evaluator` must be provided to call `evaluate()` " raise ValueError("`evaluator` must be provided to call `evaluate()` "
...@@ -217,6 +220,8 @@ class Controller: ...@@ -217,6 +220,8 @@ class Controller:
self.eval_summary_manager.write_summaries(eval_outputs) self.eval_summary_manager.write_summaries(eval_outputs)
self.eval_summary_manager.flush() self.eval_summary_manager.flush()
return eval_outputs
def restore_checkpoint(self, checkpoint_path: Text = None): def restore_checkpoint(self, checkpoint_path: Text = None):
"""Restore or initialize the model. """Restore or initialize the model.
......
...@@ -329,7 +329,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -329,7 +329,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
checkpoint_manager=checkpoint_manager, checkpoint_manager=checkpoint_manager,
summary_dir=os.path.join(self.model_dir, "summaries/train"), summary_dir=os.path.join(self.model_dir, "summaries/train"),
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval")) eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
test_controller.evaluate(steps=2) eval_results = test_controller.evaluate(steps=2)
# Only eval summaries are written # Only eval summaries are written
self.assertFalse( self.assertFalse(
...@@ -339,6 +339,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -339,6 +339,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
self.assertNotEmpty( self.assertNotEmpty(
summaries_with_matching_keyword( summaries_with_matching_keyword(
"eval_loss", os.path.join(self.model_dir, "summaries/eval"))) "eval_loss", os.path.join(self.model_dir, "summaries/eval")))
self.assertIn("eval_loss", eval_results)
# Tests continuous eval with timeout and timeout_fn. # Tests continuous eval with timeout and timeout_fn.
done_file = os.path.join(self.model_dir, "summaries/eval/Done") done_file = os.path.join(self.model_dir, "summaries/eval/Done")
......
...@@ -378,7 +378,7 @@ def get_value(x) -> np.ndarray: ...@@ -378,7 +378,7 @@ def get_value(x) -> np.ndarray:
x: input variable. x: input variable.
Returns: Returns:
A Numpy array. A Numpy array or number.
""" """
if not tf.is_tensor(x): if not tf.is_tensor(x):
return x return x
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment