Commit 134c9508 authored by Simon Kornblith's avatar Simon Kornblith Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 323493207
parent 8b8524c7
......@@ -16,8 +16,9 @@
"""A light weight utilities to train TF2 models."""
import time
from typing import Callable, Optional, Text, Union
from typing import Callable, Dict, Optional, Text, Union
from absl import logging
import numpy as np
from orbit import runner
from orbit import utils
......@@ -177,7 +178,7 @@ class Controller:
if checkpoint_at_completion:
self.save_checkpoint()
def evaluate(self, steps: int = None):
def evaluate(self, steps: int = None) -> Optional[Dict[Text, np.number]]:
"""Runs evaluation.
This method calls the `evaluate` method on the Evaluator object for `steps`
......@@ -186,10 +187,12 @@ class Controller:
Args:
steps: The number of steps to evaluate for.
Returns:
The evaluation results as a dictionary of numpy values.
Raises:
ValueError: If no checkpoint found in `self.checkpoint_manager.directory`.
ValueError: If `evaluator` is not provided.
"""
if self.evaluator is None:
raise ValueError("`evaluator` must be provided to call `evaluate()` "
......@@ -217,6 +220,8 @@ class Controller:
self.eval_summary_manager.write_summaries(eval_outputs)
self.eval_summary_manager.flush()
return eval_outputs
def restore_checkpoint(self, checkpoint_path: Text = None):
"""Restore or initialize the model.
......
......@@ -329,7 +329,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
checkpoint_manager=checkpoint_manager,
summary_dir=os.path.join(self.model_dir, "summaries/train"),
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
test_controller.evaluate(steps=2)
eval_results = test_controller.evaluate(steps=2)
# Only eval summaries are written
self.assertFalse(
......@@ -339,6 +339,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
self.assertNotEmpty(
summaries_with_matching_keyword(
"eval_loss", os.path.join(self.model_dir, "summaries/eval")))
self.assertIn("eval_loss", eval_results)
# Tests continuous eval with timeout and timeout_fn.
done_file = os.path.join(self.model_dir, "summaries/eval/Done")
......
......@@ -378,7 +378,7 @@ def get_value(x) -> np.ndarray:
x: input variable.
Returns:
A Numpy array.
A Numpy array or number.
"""
if not tf.is_tensor(x):
return x
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment