Commit 31ca3b97 authored by Kaushik Shivakumar's avatar Kaushik Shivakumar
Browse files

resolve merge conflicts

parents 3e9d886d 7fcd7cba
![TensorFlow Requirement: 2.x](https://img.shields.io/badge/TensorFlow%20Requirement-2.x-brightgreen)
# Orbit
Orbit is a customized training loop library built on top of TensorFlow 2. It
provides a flexible lightweight library that users can easily use or fork when
writing [customized training loop code](https://www.tensorflow.org/tutorials/distribute/custom_training)
in TF2. It integrates with `tf.distribute` seamlessly and supports running on
different device types (CPU, GPU, and TPU).
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Orbit package definition."""
from orbit import utils
from orbit.controller import Controller
from orbit.runner import *
from orbit.standard_runner import *
# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -14,51 +15,54 @@ ...@@ -14,51 +15,54 @@
# ============================================================================== # ==============================================================================
"""A light weight utilities to train TF2 models.""" """A light weight utilities to train TF2 models."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import time import time
from typing import Callable, Optional, Text, Union
from absl import logging from absl import logging
from orbit import runner
from orbit import utils
import tensorflow as tf
import tensorflow.compat.v2 as tf def _log_info(message: Text):
from typing import Callable, Dict, Optional, Text """Logs `message` to the `info` log, and also prints to stdout."""
logging.info(message)
print(message)
from official.staging.training import utils
def _validate_interval(interval: Optional[int], steps_per_loop: Optional[int],
interval_name: str):
if interval and steps_per_loop and (interval % steps_per_loop != 0):
raise ValueError("The {} interval ({}) must be a multiple "
"of the steps_per_loop ({})".format(
interval_name, interval, steps_per_loop))
class Controller(object):
class Controller:
"""Class that facilitates training and evaluation of models.""" """Class that facilitates training and evaluation of models."""
def __init__( def __init__(
self, self,
strategy: Optional[tf.distribute.Strategy] = None, strategy: Optional[tf.distribute.Strategy] = None,
train_fn: Optional[Callable[[tf.Tensor], trainer: Optional[runner.AbstractTrainer] = None,
Optional[Dict[Text, tf.Tensor]]]] = None, evaluator: Optional[runner.AbstractEvaluator] = None,
eval_fn: Optional[Callable[[tf.Tensor],
Optional[Dict[Text, tf.Tensor]]]] = None,
global_step: Optional[tf.Variable] = None, global_step: Optional[tf.Variable] = None,
# Train related # Train related
train_steps: Optional[int] = None,
steps_per_loop: Optional[int] = None, steps_per_loop: Optional[int] = None,
summary_dir: Optional[Text] = None,
checkpoint_manager: Optional[tf.train.CheckpointManager] = None, checkpoint_manager: Optional[tf.train.CheckpointManager] = None,
# summary related # Summary related
summary_interval: Optional[int] = None, summary_interval: Optional[int] = None,
summary_dir: Optional[Text] = None,
# Evaluation related # Evaluation related
eval_summary_dir: Optional[Text] = None, eval_summary_dir: Optional[Text] = None):
eval_steps: Optional[int] = None,
eval_interval: Optional[int] = None):
"""Constructs a `Controller` instance. """Constructs a `Controller` instance.
Args: Args:
strategy: An instance of `tf.distribute.Strategy`. strategy: An instance of `tf.distribute.Strategy`.
train_fn: A callable defined as `def train_fn(num_steps)`, which trainer: An instance of `orbit.AbstractTrainer`, which represents model
`num_steps` indicates the number of steps to run for each loop. training details.
eval_fn: A callable defined as `def eval_fn(num_steps)`, which `num_steps` evaluator: An instance of `orbit.AbstractEvaluator`, which represents
indicates the number of steps for one evaluation. model evaluation details.
global_step: An integer `tf.Variable` indicating the global training step global_step: An integer `tf.Variable` indicating the global training step
number. Usually this can be obtained from `iterations` property of the number. Usually this can be obtained from `iterations` property of the
model's optimizer (e.g. `self.optimizer.iterations`), or users can model's optimizer (e.g. `self.optimizer.iterations`), or users can
...@@ -66,259 +70,328 @@ class Controller(object): ...@@ -66,259 +70,328 @@ class Controller(object):
own global step variable, it is recommended to create the `tf.Variable` own global step variable, it is recommended to create the `tf.Variable`
inside strategy scope, and with inside strategy scope, and with
`aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA`. `aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA`.
train_steps: The total (maximum) number of training steps to perform.
steps_per_loop: The number of steps to run in each "inner loop" of steps_per_loop: The number of steps to run in each "inner loop" of
training (passed to the `num_steps` parameter of `train_fn`). training (passed to the `num_steps` parameter of `trainer.train`).
summary_dir: The directory to restore and write checkpoints and summaries.
If None, it will be set to `checkpoint_manager.directory`.
checkpoint_manager: An instance of `tf.train.CheckpointManager`. checkpoint_manager: An instance of `tf.train.CheckpointManager`.
summary_interval: Step interval for training summaries. Note that this summary_interval: Step interval for training summaries. Note that this
argument only applies to the summaries outside the training loop. If the argument only applies to the summaries inside `trainer.train` function.
value is None, then training summaries are not enabled. Summaries outside like "steps_per_second" and outputs from
`trainer.train` function will always be enabled. If set, the value
should be divisible by steps_per_loop.
summary_dir: The directory to restore and write checkpoints and summaries.
If None, it will be set to `checkpoint_manager.directory`.
eval_summary_dir: The directory to write eval summaries. If None, it will eval_summary_dir: The directory to write eval summaries. If None, it will
be set to `summary_dir`. be set to `summary_dir`.
eval_steps: Number of steps to run evaluation.
eval_interval: Step interval for evaluation. If None, will skip evaluation
in the middle of training. Note that evaluation only happens outside the
training loop, which the loop iteration is specify by `steps_per_loop`
parameter.
Raises: Raises:
ValueError: If both `train_fn` and `eval_fn` are None. ValueError: If both `trainer` and `evaluator` are None.
ValueError: If `train_fn` is not None and `train_steps` is None.
ValueError: If `steps_per_loop` is None when `train_fn` is provided.
ValueError: If `steps_per_loop` is not a positive integer. ValueError: If `steps_per_loop` is not a positive integer.
ValueError: If `summary_interval` is not a positive integer or it cannot
be divisible by `steps_per_loop`.
""" """
if train_fn is None and eval_fn is None: if trainer is None and evaluator is None:
raise ValueError("`train_fn` and `eval_fn` should not both be None") raise ValueError("`trainer` and `evaluator` should not both be None")
# TODO(rxsang): Support training until exhaustion by passing if trainer is not None:
# `train_steps=-1`. Currently it cannot be supported with a host training
# loop because break statements are not supported with distributed dataset.
if train_fn is not None:
if train_steps is None:
raise ValueError("`train_steps` is required when `train_fn` is "
"provided.")
if steps_per_loop is None: if steps_per_loop is None:
raise ValueError("`steps_per_loop` is required when `train_fn is " raise ValueError("`steps_per_loop` is required when `trainer` is "
"provided.") "provided.")
if not isinstance(steps_per_loop, int) or steps_per_loop < 1: if not isinstance(steps_per_loop, int) or steps_per_loop < 1:
raise ValueError("`steps_per_loop` should be a positive integer") raise ValueError("`steps_per_loop` should be a positive integer")
if summary_interval is not None and summary_interval <= 0:
raise ValueError("`summary_interval` should be larger than 0") if summary_interval is not None:
if summary_interval <= 0:
raise ValueError("`summary_interval` should be larger than 0")
_validate_interval(
summary_interval, steps_per_loop, interval_name="summary")
self.trainer = trainer
self.evaluator = evaluator
self.strategy = strategy or tf.distribute.get_strategy() self.strategy = strategy or tf.distribute.get_strategy()
self.train_fn = train_fn
self.eval_fn = eval_fn
self.global_step = global_step self.global_step = global_step
self.checkpoint_manager = checkpoint_manager self.checkpoint_manager = checkpoint_manager
if self.train_fn is not None: if summary_dir is None and checkpoint_manager:
self.train_steps = train_steps summary_dir = checkpoint_manager.directory
self.steps_per_loop = steps_per_loop
if summary_dir:
self.summary_dir = summary_dir
elif checkpoint_manager:
self.summary_dir = checkpoint_manager.directory
else:
self.summary_dir = None
if self.trainer is not None:
self.step_timer = None
self.steps_per_loop = steps_per_loop
self.summary_interval = summary_interval self.summary_interval = summary_interval
if self.summary_dir and self.summary_interval:
summary_writer = tf.summary.create_file_writer(self.summary_dir)
else:
summary_writer = None
# TODO(rxsang): Consider pass SummaryManager directly into Controller for
# maximum customizability.
self.summary_manager = utils.SummaryManager( self.summary_manager = utils.SummaryManager(
summary_writer, summary_dir, tf.summary.scalar, global_step=self.global_step)
tf.summary.scalar,
global_step=self.global_step, eval_summary_writer = None
summary_interval=self.summary_interval) if self.evaluator is not None:
eval_summary_dir = eval_summary_dir or summary_dir
if self.eval_fn is not None: if eval_summary_dir == summary_dir and self.trainer is not None:
eval_summary_dir = eval_summary_dir or self.summary_dir # Reuse the summary writer if train and evaluation summary directory
eval_summary_writer = tf.summary.create_file_writer( # are the same.
eval_summary_dir) if eval_summary_dir else None self.eval_summary_manager = self.summary_manager
self.eval_summary_manager = utils.SummaryManager( else:
eval_summary_writer, tf.summary.scalar, global_step=self.global_step) self.eval_summary_manager = utils.SummaryManager(
eval_summary_dir, tf.summary.scalar, global_step=self.global_step)
self.eval_steps = eval_steps
self.eval_interval = eval_interval if self.global_step is not None:
# Creates and initializes the interval triggers.
self.eval_trigger = utils.IntervalTrigger(self.eval_interval,
self.global_step.numpy()) # pytype: disable=attribute-error
if self.global_step:
tf.summary.experimental.set_step(self.global_step) tf.summary.experimental.set_step(self.global_step)
# Restores the model if needed. # Restores the model if needed.
# TODO(momernick): We probably only want to do this on certain occasions?
if self.checkpoint_manager is not None: if self.checkpoint_manager is not None:
model_restored = self._restore_model() checkpoint_interval = self.checkpoint_manager.checkpoint_interval
if not model_restored and self.checkpoint_manager.checkpoint_interval: _validate_interval(
# If the model is not restored from a checkpoint, save an initial checkpoint_interval, steps_per_loop, interval_name="checkpoint")
model_restored = self.restore_checkpoint()
if not model_restored and (checkpoint_interval and
self.trainer is not None):
# If the model is not restored from a checkpoint, and
# `checkpoint_interval` is enabled for training, save an initial
# checkpoint. # checkpoint.
ckpt_path = self.checkpoint_manager.save( self.save_checkpoint()
checkpoint_number=self.global_step)
logging.info("Saved checkpoins in %s", ckpt_path)
def _restore_model(self, checkpoint_path=None): def train(self, steps: int, checkpoint_at_completion: bool = True):
"""Restore or initialize the model. """Runs training.
This method calls the `train` method on the Trainable object until the
global step count is equal to `steps`. It will optionally save checkpoints,
if a CheckpointManager was passed to the Controller instance's `__init__`.
Args: Args:
checkpoint_path: An optional string indicates the checkpoint path to steps: The global step count to train up to.
restore. If None, will restore from `self.checkpoint_manager`. checkpoint_at_completion: Whether to save a checkpoint when this method
returns. Defaults to True (write the checkpoint). This is always
triggered, regardless of the checkpointing interval.
"""
if self.trainer is None:
raise ValueError("`self.trainer` is required when calling `train` "
"method.")
if self.global_step is None:
raise ValueError("`self.global_step` is required when calling `train` "
"method.")
# TODO(momernick): Support steps=None or -1 (training to exhaustion).
current_step = self.global_step.numpy() # This is an expensive access.
while current_step < steps:
logging.info("Train at step %s of %s", current_step, steps)
# Calculates steps to run for the next train loop.
num_steps = min(steps - current_step, self.steps_per_loop)
self._train_n_steps(num_steps)
self._maybe_save_checkpoint()
current_step = self.global_step.numpy() # This is an expensive access.
if checkpoint_at_completion:
self.save_checkpoint()
def evaluate(self, steps: int = None):
"""Runs evaluation.
This method calls the `evaluate` method on the Evaluator object for `steps`
steps, then writes the returned summaries (if any).
Args:
steps: The number of steps to evaluate for.
Raises:
ValueError: If no checkpoint found in `self.checkpoint_manager.directory`.
ValueError: If `evaluator` is not provided.
Returns:
True if the latest checkpoint is found or restored. Otherwise False.
""" """
with self.strategy.scope(): if self.evaluator is None:
# Checkpoint restoring should be inside scope. b/139450638 raise ValueError("`evaluator` must be provided to call `evaluate()` "
if checkpoint_path is not None: "method.")
self.checkpoint_manager.checkpoint.restore(checkpoint_path)
return True
return self.checkpoint_manager.restore_or_initialize()
def _evaluate_once(self, current_step): steps = steps or -1
"""Runs the evaluation once.""" current_step = self.global_step.numpy()
logging.info("Start evaluation at step: %s", current_step) if steps > 0:
logging.info("Running %s steps of evaluation at train step: %s", steps,
current_step)
steps = tf.convert_to_tensor(steps, dtype=tf.int32)
else:
logging.info("Evaluating at train step: %s", current_step)
with self.eval_summary_manager.summary_writer.as_default(): with self.eval_summary_manager.summary_writer.as_default():
eval_outputs = self.eval_fn(self.eval_steps) eval_outputs = self.evaluator.evaluate(steps)
if eval_outputs: if eval_outputs:
eval_outputs = tf.nest.map_structure(lambda x: x.numpy(), eval_outputs) eval_outputs = tf.nest.map_structure(utils.get_value, eval_outputs)
info = "step: {} evaluation metric: {}".format( info = "step: {} evaluation metric: {}".format(
current_step, eval_outputs) current_step, eval_outputs)
self._log_info(info) _log_info(info)
self.eval_summary_manager.write_summaries(eval_outputs) self.eval_summary_manager.write_summaries(eval_outputs)
self.eval_summary_manager.flush() self.eval_summary_manager.flush()
def _maybe_save_checkpoints(self, current_step, force_trigger=False): def restore_checkpoint(self, checkpoint_path: Text = None):
if self.checkpoint_manager and self.checkpoint_manager.checkpoint_interval: """Restore or initialize the model.
ckpt_path = self.checkpoint_manager.save(
checkpoint_number=current_step, check_interval=not force_trigger) Args:
if ckpt_path is not None: checkpoint_path: An optional string indicates the checkpoint path to
logging.info("Saved checkpoins in %s", ckpt_path) restore. If None, will restore from `self.checkpoint_manager`.
Returns:
The path to the restored checkpoint if a restore happened, or None
if no restore occurred.
"""
with self.strategy.scope():
# Checkpoint restoring should be inside scope. b/139450638
if checkpoint_path is not None:
self.checkpoint_manager.checkpoint.restore(checkpoint_path)
return checkpoint_path
return self.checkpoint_manager.restore_or_initialize()
def save_checkpoint(self):
"""Checkpoint the model.
def _maybe_evaluate(self, current_step, force_trigger=False): This method will write a checkpoint containing the current state of the
if self.eval_trigger(current_step, force_trigger): model.
self._evaluate_once(current_step)
def _log_info(self, message): Raises:
"""Logs `message` to the `info` log, and also prints to stdout.""" ValueError: if no CheckpointManager was provided to this Controller's
logging.info(message) init args.
print(message) """
self._maybe_save_checkpoint(force_trigger=True)
def train(self, evaluate=True): def train_and_evaluate(self,
"""Runs the training, with optional evaluation. train_steps: int = None,
eval_steps: int = None,
eval_interval: int = None):
"""Train and evaluate in an interleaved manner.
This handles evaluation, gathering summaries, and saving checkpoints. This method will train the model until the global step count equals
`train_steps`, running an evaluation for `eval_steps` every `eval_interval`
training steps. In addition, this method will run a final evaluation at the
end of the training sequence.
Args: Args:
evaluate: A boolean indicates whether to perform evaluation during train_steps: The global step count to train up to.
training. eval_steps: The number of steps to run during an evaluation. If None,
this method will evaluate over the entire evaluation dataset.
eval_interval: The number of training steps to run between evalutions.
Must be a multiple of the controller's `steps_per_loop` init arg. If
None, evaluation will only be performed after training is complete.
Raises: Raises:
RuntimeError: If `global_step` is not updated correctly in `train_fn`. ValueError: If eval_interval is not a multiple of self.steps_per_loop.
""" """
if self.train_fn is None: _validate_interval(eval_interval, self.steps_per_loop, interval_name="eval")
raise ValueError("`self.train_fn` is required when calling `train` "
"method.") current_step = self.global_step.numpy() # This is an expensive access.
if self.global_step is None: eval_interval = eval_interval or (train_steps - current_step)
raise ValueError("`self.global_step` is required when calling `train` " while current_step < train_steps:
"method.") interval = min(train_steps - current_step, eval_interval)
if evaluate and self.eval_fn is None: num_steps = current_step + interval
raise ValueError("`self.eval_fn` is required when calling `train` method " self.train(steps=num_steps, checkpoint_at_completion=False)
"with `evaluate=True`") self.evaluate(steps=eval_steps)
current_step = self.global_step.numpy() # This is an expensive access.
step_timer = _StepTimer(self.global_step) self.save_checkpoint()
current_step = self.global_step.numpy()
logging.info("Train at step %s of %s", current_step, self.train_steps) def evaluate_continuously(self,
while current_step < self.train_steps: steps: int = None,
# Calculates steps to run for the next train loop. timeout: Optional[Union[int, float]] = None,
steps_per_loop = min(self.train_steps - current_step, self.steps_per_loop) timeout_fn: Optional[Callable[[], bool]] = None):
logging.info("Entering training loop with %s steps, at step %s of %s", """Monitor a directory and evaluate on checkpoints in it.
steps_per_loop, current_step, self.train_steps)
current_step += steps_per_loop This method continuously monitors a directory as specified by this
steps_per_loop = tf.convert_to_tensor(steps_per_loop, dtype=tf.int32) Controller's CheckpointManager init arg and runs evaluation on the
checkpoints found there.
with self.summary_manager.summary_writer.as_default():
train_outputs = self.train_fn(steps_per_loop)
# Updates and verifies the current step after a training loop finishes.
if current_step != self.global_step.numpy():
raise RuntimeError("`self.train_fn` is not updating `global_step` "
"correctly, expected: %s, actual: %s" %
(current_step, self.global_step.numpy()))
# Print information like metrics and steps_per_second after a training
# loop.
if train_outputs:
train_outputs = tf.nest.map_structure(
lambda x: x.numpy(), train_outputs)
steps_per_second = step_timer.steps_per_second()
info = "step: {} steps_per_second: {:.2f} {}".format(
current_step, steps_per_second, train_outputs)
self._log_info(info)
train_outputs = train_outputs or {}
train_outputs["steps_per_second"] = steps_per_second
self.summary_manager.write_summaries(train_outputs)
self._maybe_save_checkpoints(current_step)
if evaluate:
self._maybe_evaluate(current_step)
self.summary_manager.write_summaries(train_outputs, always_write=True)
self.summary_manager.flush()
self._maybe_save_checkpoints(current_step, force_trigger=True)
if evaluate:
self._maybe_evaluate(current_step, force_trigger=True)
def evaluate(self, continuous=False, timeout_fn=None):
"""Runs the evaluation.
Args: Args:
continuous: If `True`, will continously monitor the checkpoint directory steps: The number of steps to run when evaluating.
to evaluate on the latest checkpoint. If `False`, will do the evaluation timeout: The maximum number of seconds to wait between checkpoints. See
once. tf.train.checkpoints_iterator documentation.
timeout_fn: Optional callable to call after a timeout. If the function timeout_fn: Optional callable to call after a timeout. If the function
returns True, then it means that no new checkpoints will be generated returns True, then it means that no new checkpoints will be generated
and the iterator will exit. and the iterator will exit.
Raises: Raises:
ValueError: If no checkpoint found in `self.checkpoint_manager.directory`. ValueError: If no checkpoint found in `self.checkpoint_manager.directory`.
ValueError: If `evaluator` was not provided as a controller init arg.
"""
for checkpoint_path in tf.train.checkpoints_iterator(
self.checkpoint_manager.directory,
timeout=timeout,
timeout_fn=timeout_fn):
self.restore_checkpoint(checkpoint_path)
self.evaluate(steps)
def _train_n_steps(self, num_steps: int):
"""Run training for `num_steps`.
It will also write training outputs to summaries if there is any.
Args:
num_steps: An integer indicates how many steps to run for this training
loop.
Raises:
RuntimeError: If `global_step` is not updated correctly in
`trainer.train`.
"""
if not self.step_timer:
self.step_timer = StepTimer(self.global_step)
# Calculates steps to run for the next train loop.
current_step = self.global_step.numpy()
logging.info("Entering training loop at step %s to run %s steps",
current_step, num_steps)
current_step += num_steps
num_steps = tf.convert_to_tensor(num_steps, dtype=tf.int32)
with self.summary_manager.summary_writer.as_default():
# Create a lambda that returns true when summaries should be written.
should_record = False # Allows static optimization in no-summary cases.
if self.summary_interval:
should_record = lambda: (self.global_step % self.summary_interval == 0)
with tf.summary.record_if(should_record):
train_outputs = self.trainer.train(num_steps)
# Updates and verifies the current step after a training loop finishes.
if current_step != self.global_step.numpy():
raise RuntimeError("`trainer.train` function is not updating "
"`global_step` correctly, expected: %s, actual: %s" %
(current_step, self.global_step.numpy()))
# Print information like metrics and steps_per_second after a training
# loop.
if train_outputs:
train_outputs = tf.nest.map_structure(utils.get_value, train_outputs)
train_outputs = train_outputs or {}
steps_per_second = self.step_timer.steps_per_second()
info = "step: {} steps_per_second: {:.2f} {}".format(
current_step, steps_per_second, train_outputs)
_log_info(info)
train_outputs["steps_per_second"] = steps_per_second
self.summary_manager.write_summaries(train_outputs)
def _maybe_save_checkpoint(self, force_trigger: bool = False):
"""Save checkpoints if necessary.
Args:
force_trigger: A boolean indicates whether to force saving checkpoints
regardless of the checkpoint interval.
Returns:
A boolean indicating whether a checkpoint was saved.
""" """
if self.eval_fn is None: if self.checkpoint_manager and self.checkpoint_manager.checkpoint_interval:
raise ValueError("`self.eval_fn` should not be None to call " ckpt_path = self.checkpoint_manager.save(
"`evaluate()` method.") checkpoint_number=self.global_step.numpy(),
check_interval=not force_trigger)
if not continuous and timeout_fn is not None: if ckpt_path is not None:
raise ValueError("`timeout_fn` can be only passed when `continuous` is " logging.info("Saved checkpoints in %s", ckpt_path)
"True") return True
return False
if continuous:
for checkpoint_path in tf.train.checkpoints_iterator(
self.checkpoint_manager.directory, timeout_fn=timeout_fn): class StepTimer:
self._restore_model(checkpoint_path)
self._evaluate_once(self.global_step.numpy())
return
latest_checkpoint = self.checkpoint_manager.latest_checkpoint
if not latest_checkpoint:
raise ValueError("no checkpoint found in dir %s" %
self.checkpoint_manager.directory)
self._restore_model()
self._evaluate_once(self.global_step.numpy())
class _StepTimer(object):
"""Utility class for measuring steps/second.""" """Utility class for measuring steps/second."""
def __init__(self, step): def __init__(self, step):
......
# Copyright 2020 The TensorFlow Authors. All Rights Reserved. # Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -12,35 +13,16 @@ ...@@ -12,35 +13,16 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Tests for official.staging.training.controller.""" """Tests for orbit.controller."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os import os
from absl import logging
from absl.testing import parameterized from absl.testing import parameterized
import numpy as np import numpy as np
import tensorflow as tf from orbit import controller
from orbit import standard_runner
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.staging.training import controller
from official.staging.training import standard_runnable
def all_strategy_combinations(): import tensorflow as tf
"""Gets combinations of distribution strategies."""
return combinations.combine(
strategy=[
strategy_combinations.one_device_strategy,
strategy_combinations.tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
],
mode="eager",
)
def create_model(): def create_model():
...@@ -57,7 +39,7 @@ def summaries_with_matching_keyword(keyword, summary_dir): ...@@ -57,7 +39,7 @@ def summaries_with_matching_keyword(keyword, summary_dir):
if event.summary is not None: if event.summary is not None:
for value in event.summary.value: for value in event.summary.value:
if keyword in value.tag: if keyword in value.tag:
tf.compat.v1.logging.error(event) logging.info(event)
yield event.summary yield event.summary
...@@ -69,30 +51,33 @@ def check_eventfile_for_keyword(keyword, summary_dir): ...@@ -69,30 +51,33 @@ def check_eventfile_for_keyword(keyword, summary_dir):
def dataset_fn(ctx): def dataset_fn(ctx):
del ctx del ctx
inputs = np.zeros((10, 3), dtype=np.float32) inputs = np.zeros((10, 3), dtype=np.float32)
targets = np.zeros((10, 4), dtype=np.float32) targets = np.ones((10, 4), dtype=np.float32)
dataset = tf.data.Dataset.from_tensor_slices((inputs, targets)) dataset = tf.data.Dataset.from_tensor_slices((inputs, targets))
dataset = dataset.repeat(100) dataset = dataset.repeat(100)
dataset = dataset.batch(10, drop_remainder=True) dataset = dataset.batch(10, drop_remainder=True)
return dataset return dataset
class TestRunnable(standard_runnable.StandardTrainable, class TestRunner(standard_runner.StandardTrainer,
standard_runnable.StandardEvaluable): standard_runner.StandardEvaluator):
"""Implements the training and evaluation APIs for the test model.""" """Implements the training and evaluation APIs for the test model."""
def __init__(self): def __init__(self, return_numpy=False):
standard_runnable.StandardTrainable.__init__(self)
standard_runnable.StandardEvaluable.__init__(self)
self.strategy = tf.distribute.get_strategy() self.strategy = tf.distribute.get_strategy()
self.model = create_model() self.model = create_model()
self.optimizer = tf.keras.optimizers.RMSprop() self.optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.1)
self.global_step = self.optimizer.iterations self.global_step = self.optimizer.iterations
self.train_loss = tf.keras.metrics.Mean("train_loss", dtype=tf.float32) self.train_loss = tf.keras.metrics.Mean("train_loss", dtype=tf.float32)
self.eval_loss = tf.keras.metrics.Mean("eval_loss", dtype=tf.float32) self.eval_loss = tf.keras.metrics.Mean("eval_loss", dtype=tf.float32)
self.return_numpy = return_numpy
def build_train_dataset(self): train_dataset = (
return self.strategy.experimental_distribute_datasets_from_function( self.strategy.experimental_distribute_datasets_from_function(dataset_fn)
dataset_fn) )
eval_dataset = (
self.strategy.experimental_distribute_datasets_from_function(dataset_fn)
)
standard_runner.StandardTrainer.__init__(self, train_dataset)
standard_runner.StandardEvaluator.__init__(self, eval_dataset)
def train_step(self, iterator): def train_step(self, iterator):
...@@ -101,7 +86,7 @@ class TestRunnable(standard_runnable.StandardTrainable, ...@@ -101,7 +86,7 @@ class TestRunnable(standard_runnable.StandardTrainable,
inputs, targets = inputs inputs, targets = inputs
with tf.GradientTape() as tape: with tf.GradientTape() as tape:
outputs = self.model(inputs) outputs = self.model(inputs)
loss = tf.math.reduce_sum(outputs - targets) loss = tf.reduce_mean(tf.keras.losses.MSE(targets, outputs))
grads = tape.gradient(loss, self.model.variables) grads = tape.gradient(loss, self.model.variables)
self.optimizer.apply_gradients(zip(grads, self.model.variables)) self.optimizer.apply_gradients(zip(grads, self.model.variables))
self.train_loss.update_state(loss) self.train_loss.update_state(loss)
...@@ -109,8 +94,9 @@ class TestRunnable(standard_runnable.StandardTrainable, ...@@ -109,8 +94,9 @@ class TestRunnable(standard_runnable.StandardTrainable,
self.strategy.run(_replicated_step, args=(next(iterator),)) self.strategy.run(_replicated_step, args=(next(iterator),))
def train_loop_end(self): def train_loop_end(self):
train_loss = self.train_loss.result()
return { return {
"loss": self.train_loss.result(), "loss": train_loss.numpy() if self.return_numpy else train_loss,
} }
def build_eval_dataset(self): def build_eval_dataset(self):
...@@ -126,39 +112,110 @@ class TestRunnable(standard_runnable.StandardTrainable, ...@@ -126,39 +112,110 @@ class TestRunnable(standard_runnable.StandardTrainable,
"""Replicated evaluation step.""" """Replicated evaluation step."""
inputs, targets = inputs inputs, targets = inputs
outputs = self.model(inputs) outputs = self.model(inputs)
loss = tf.math.reduce_sum(outputs - targets) loss = tf.reduce_mean(tf.keras.losses.MSE(targets, outputs))
self.eval_loss.update_state(loss) self.eval_loss.update_state(loss)
self.strategy.run(_replicated_step, args=(next(iterator),)) self.strategy.run(_replicated_step, args=(next(iterator),))
def eval_end(self): def eval_end(self):
eval_loss = self.eval_loss.result()
return {
"eval_loss": eval_loss.numpy() if self.return_numpy else eval_loss,
}
class TestEvaluator(standard_runner.StandardEvaluator):
"""Implements the training and evaluation APIs for the test model."""
def __init__(self):
self.strategy = tf.distribute.get_strategy()
self.model = create_model()
eval_dataset = self.strategy.experimental_distribute_datasets_from_function(
dataset_fn)
standard_runner.StandardEvaluator.__init__(self, eval_dataset)
def eval_reduce(self, state, output):
state.append(output)
return state
def eval_begin(self):
return []
def eval_step(self, iterator):
def _replicated_step(inputs):
"""Replicated evaluation step."""
inputs, targets = inputs
outputs = self.model(inputs)
loss = tf.reduce_mean(tf.keras.losses.MSE(targets, outputs))
return loss
per_replica_losses = self.strategy.run(
_replicated_step, args=(next(iterator),))
mean_loss = self.strategy.reduce(
tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
return mean_loss
def eval_end(self, outputs):
return { return {
"eval_loss": self.eval_loss.result(), "eval_loss": tf.reduce_mean(outputs),
} }
class TestTrainerWithSummaries(standard_runner.StandardTrainer):
"""A Trainer model with summaries for testing purposes."""
def __init__(self):
self.strategy = tf.distribute.get_strategy()
self.model = create_model()
self.optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.1)
self.global_step = self.optimizer.iterations
self.train_loss = tf.keras.metrics.Mean("train_loss", dtype=tf.float32)
train_dataset = (
self.strategy.experimental_distribute_datasets_from_function(dataset_fn)
)
standard_runner.StandardTrainer.__init__(
self, train_dataset, use_tpu_summary_optimization=True)
def build_train_dataset(self):
return self.strategy.experimental_distribute_datasets_from_function(
dataset_fn)
def train_step(self, iterator):
def _replicated_step(inputs):
"""Replicated training step."""
inputs, targets = inputs
with tf.GradientTape() as tape:
outputs = self.model(inputs)
loss = tf.reduce_mean(tf.keras.losses.MSE(targets, outputs))
tf.summary.scalar("loss", loss)
grads = tape.gradient(loss, self.model.variables)
self.optimizer.apply_gradients(zip(grads, self.model.variables))
self.train_loss.update_state(loss)
self.strategy.run(_replicated_step, args=(next(iterator),))
class ControllerTest(tf.test.TestCase, parameterized.TestCase): class ControllerTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self): def setUp(self):
super(ControllerTest, self).setUp() super().setUp()
self.model_dir = self.get_temp_dir() self.model_dir = self.get_temp_dir()
def test_no_checkpoint(self): def test_no_checkpoint(self):
test_runnable = TestRunnable() test_runner = TestRunner()
# No checkpoint manager and no strategy. # No checkpoint manager and no strategy.
test_controller = controller.Controller( test_controller = controller.Controller(
train_fn=test_runnable.train, trainer=test_runner,
eval_fn=test_runnable.evaluate, evaluator=test_runner,
global_step=test_runnable.global_step, global_step=test_runner.global_step,
train_steps=10,
steps_per_loop=2, steps_per_loop=2,
summary_dir=os.path.join(self.model_dir, "summaries/train"), summary_dir=os.path.join(self.model_dir, "summaries/train"),
summary_interval=2, eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"), test_controller.train_and_evaluate(
eval_steps=2, train_steps=10, eval_steps=2, eval_interval=6)
eval_interval=5) self.assertEqual(test_runner.global_step, 10)
test_controller.train(evaluate=True)
self.assertEqual(test_runnable.global_step.numpy(), 10)
# Loss and accuracy values should be written into summaries. # Loss and accuracy values should be written into summaries.
self.assertNotEmpty( self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train"))) tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
...@@ -171,51 +228,46 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -171,51 +228,46 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
check_eventfile_for_keyword( check_eventfile_for_keyword(
"eval_loss", os.path.join(self.model_dir, "summaries/eval"))) "eval_loss", os.path.join(self.model_dir, "summaries/eval")))
# No checkpoint, so global step starts from 0. # No checkpoint, so global step starts from 0.
test_runnable.global_step.assign(0) test_runner.global_step.assign(0)
test_controller.train(evaluate=True) test_controller.train_and_evaluate(
self.assertEqual(test_runnable.global_step.numpy(), 10) train_steps=10, eval_steps=2, eval_interval=6)
self.assertEqual(test_runner.global_step, 10)
def test_no_checkpoint_and_summaries(self): def test_no_checkpoint_and_summaries(self):
test_runnable = TestRunnable() test_runner = TestRunner()
# No checkpoint + summary directories. # No checkpoint + summary directories.
test_controller = controller.Controller( test_controller = controller.Controller(
train_fn=test_runnable.train, trainer=test_runner,
eval_fn=test_runnable.evaluate, evaluator=test_runner,
global_step=test_runnable.global_step, global_step=test_runner.global_step,
train_steps=10, steps_per_loop=2)
steps_per_loop=2, test_controller.train_and_evaluate(
eval_steps=2, train_steps=10, eval_steps=2, eval_interval=6)
eval_interval=5) self.assertEqual(test_runner.global_step, 10)
test_controller.train(evaluate=True)
self.assertEqual(test_runnable.global_step.numpy(), 10) @parameterized.named_parameters(("return_numpy", True),
("return_tensor", False))
@combinations.generate(all_strategy_combinations()) def test_train_and_evaluate(self, return_numpy):
def test_train_and_evaluate(self, strategy): test_runner = TestRunner(return_numpy=return_numpy)
with strategy.scope():
test_runnable = TestRunnable()
checkpoint = tf.train.Checkpoint( checkpoint = tf.train.Checkpoint(
model=test_runnable.model, optimizer=test_runnable.optimizer) model=test_runner.model, optimizer=test_runner.optimizer)
checkpoint_manager = tf.train.CheckpointManager( checkpoint_manager = tf.train.CheckpointManager(
checkpoint, checkpoint,
self.model_dir, self.model_dir,
max_to_keep=None, max_to_keep=None,
step_counter=test_runnable.global_step, step_counter=test_runner.global_step,
checkpoint_interval=10) checkpoint_interval=10)
test_controller = controller.Controller( test_controller = controller.Controller(
strategy=strategy, trainer=test_runner,
train_fn=test_runnable.train, evaluator=test_runner,
eval_fn=test_runnable.evaluate, global_step=test_runner.global_step,
global_step=test_runnable.global_step,
train_steps=10,
steps_per_loop=2, steps_per_loop=2,
summary_dir=os.path.join(self.model_dir, "summaries/train"), summary_dir=os.path.join(self.model_dir, "summaries/train"),
summary_interval=2,
checkpoint_manager=checkpoint_manager, checkpoint_manager=checkpoint_manager,
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"), eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
eval_steps=2, test_controller.train_and_evaluate(
eval_interval=5) train_steps=10, eval_steps=2, eval_interval=6)
test_controller.train(evaluate=True)
# Checkpoints are saved. # Checkpoints are saved.
self.assertNotEmpty(tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt*"))) self.assertNotEmpty(tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt*")))
...@@ -232,31 +284,26 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -232,31 +284,26 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
check_eventfile_for_keyword( check_eventfile_for_keyword(
"eval_loss", os.path.join(self.model_dir, "summaries/eval"))) "eval_loss", os.path.join(self.model_dir, "summaries/eval")))
@combinations.generate(all_strategy_combinations()) def test_train_only(self):
def test_train_only(self, strategy): test_runner = TestRunner()
with strategy.scope():
test_runnable = TestRunnable()
checkpoint = tf.train.Checkpoint( checkpoint = tf.train.Checkpoint(
model=test_runnable.model, optimizer=test_runnable.optimizer) model=test_runner.model, optimizer=test_runner.optimizer)
checkpoint_manager = tf.train.CheckpointManager( checkpoint_manager = tf.train.CheckpointManager(
checkpoint, checkpoint,
self.model_dir, self.model_dir,
max_to_keep=None, max_to_keep=None,
step_counter=test_runnable.global_step, step_counter=test_runner.global_step,
checkpoint_interval=10) checkpoint_interval=10)
test_controller = controller.Controller( test_controller = controller.Controller(
strategy=strategy, trainer=test_runner,
train_fn=test_runnable.train, global_step=test_runner.global_step,
global_step=test_runnable.global_step,
train_steps=10,
steps_per_loop=2, steps_per_loop=2,
summary_dir=os.path.join(self.model_dir, "summaries/train"), summary_dir=os.path.join(self.model_dir, "summaries/train"),
summary_interval=2,
checkpoint_manager=checkpoint_manager, checkpoint_manager=checkpoint_manager,
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"), eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
) )
test_controller.train(evaluate=False) test_controller.train(steps=10)
# Checkpoints are saved. # Checkpoints are saved.
self.assertNotEmpty(tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt*"))) self.assertNotEmpty(tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt*")))
...@@ -270,29 +317,23 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -270,29 +317,23 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
self.assertFalse( self.assertFalse(
tf.io.gfile.exists(os.path.join(self.model_dir, "summaries/eval"))) tf.io.gfile.exists(os.path.join(self.model_dir, "summaries/eval")))
@combinations.generate(all_strategy_combinations()) def test_evaluate_only(self):
def test_evaluate_only(self, strategy): test_runner = TestRunner()
with strategy.scope():
test_runnable = TestRunnable()
checkpoint = tf.train.Checkpoint(model=test_runnable.model) checkpoint = tf.train.Checkpoint(model=test_runner.model)
checkpoint.save(os.path.join(self.model_dir, "ckpt")) checkpoint.save(os.path.join(self.model_dir, "ckpt"))
checkpoint_manager = tf.train.CheckpointManager( checkpoint_manager = tf.train.CheckpointManager(
checkpoint, checkpoint,
self.model_dir, self.model_dir,
max_to_keep=None, max_to_keep=None,
step_counter=test_runnable.global_step) step_counter=test_runner.global_step)
test_controller = controller.Controller( test_controller = controller.Controller(
strategy=strategy, evaluator=test_runner,
eval_fn=test_runnable.evaluate, global_step=test_runner.global_step,
global_step=test_runnable.global_step,
checkpoint_manager=checkpoint_manager, checkpoint_manager=checkpoint_manager,
summary_dir=os.path.join(self.model_dir, "summaries/train"), summary_dir=os.path.join(self.model_dir, "summaries/train"),
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"), eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
eval_steps=2, test_controller.evaluate(steps=2)
eval_interval=5)
test_controller.evaluate()
# Only eval summaries are written # Only eval summaries are written
self.assertFalse( self.assertFalse(
...@@ -303,6 +344,207 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -303,6 +344,207 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
check_eventfile_for_keyword( check_eventfile_for_keyword(
"eval_loss", os.path.join(self.model_dir, "summaries/eval"))) "eval_loss", os.path.join(self.model_dir, "summaries/eval")))
# Tests continuous eval with timeout and timeout_fn.
done_file = os.path.join(self.model_dir, "summaries/eval/Done")
def timeout_fn():
with tf.io.gfile.GFile(done_file, "w") as f:
f.write("DONE")
return True
test_controller = controller.Controller(
evaluator=test_runner,
global_step=test_runner.global_step,
checkpoint_manager=checkpoint_manager,
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
test_controller.evaluate_continuously(
timeout=1, timeout_fn=timeout_fn, steps=2)
self.assertNotEmpty(tf.io.gfile.glob(done_file))
def test_no_eval_steps(self):
test_runner = TestRunner()
checkpoint = tf.train.Checkpoint(model=test_runner.model)
checkpoint.save(os.path.join(self.model_dir, "ckpt"))
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
self.model_dir,
max_to_keep=None,
step_counter=test_runner.global_step)
test_controller = controller.Controller(
evaluator=test_runner,
global_step=test_runner.global_step,
checkpoint_manager=checkpoint_manager)
test_controller.evaluate()
def test_already_trained_model(self):
test_runner = TestRunner()
test_runner.global_step.assign(10)
checkpoint = tf.train.Checkpoint(
model=test_runner.model, optimizer=test_runner.optimizer)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
self.model_dir,
max_to_keep=None,
step_counter=test_runner.global_step,
checkpoint_interval=10)
test_controller = controller.Controller(
trainer=test_runner,
global_step=test_runner.global_step,
steps_per_loop=2,
checkpoint_manager=checkpoint_manager)
# `global_step` is already `train_steps`.
test_controller.train(steps=10)
def test_summaries_inside_train_fn(self):
test_runner = TestTrainerWithSummaries()
checkpoint = tf.train.Checkpoint(
model=test_runner.model, optimizer=test_runner.optimizer)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
self.model_dir,
max_to_keep=None,
step_counter=test_runner.global_step)
test_controller = controller.Controller(
trainer=test_runner,
global_step=test_runner.global_step,
steps_per_loop=2,
summary_dir=os.path.join(self.model_dir, "summaries/train"),
summary_interval=2,
checkpoint_manager=checkpoint_manager,
)
test_controller.train(steps=10)
# Checkpoints are saved.
self.assertEmpty(tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt*")))
# Only train summaries are written.
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
self.assertTrue(
check_eventfile_for_keyword(
"loss", os.path.join(self.model_dir, "summaries/train")))
self.assertFalse(
tf.io.gfile.exists(os.path.join(self.model_dir, "summaries/eval")))
def test_train_and_evaluate_with_same_summary_dir(self):
test_runner = TestRunner()
checkpoint = tf.train.Checkpoint(
model=test_runner.model, optimizer=test_runner.optimizer)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
self.model_dir,
max_to_keep=None,
step_counter=test_runner.global_step)
test_controller = controller.Controller(
trainer=test_runner,
evaluator=test_runner,
global_step=test_runner.global_step,
steps_per_loop=2,
summary_dir=os.path.join(self.model_dir, "summaries"),
checkpoint_manager=checkpoint_manager,
eval_summary_dir=os.path.join(self.model_dir, "summaries"))
test_controller.train_and_evaluate(
train_steps=10, eval_steps=2, eval_interval=6)
# Loss and accuracy values should be written into summaries.
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries")))
self.assertTrue(
check_eventfile_for_keyword("loss",
os.path.join(self.model_dir, "summaries")))
self.assertTrue(
check_eventfile_for_keyword("eval_loss",
os.path.join(self.model_dir, "summaries")))
def test_early_stop_on_eval_loss(self):
test_runner = TestRunner()
class EarlyStopController(controller.Controller):
"""A subclass of Controller supports early stopping."""
def train_and_evaluate(self,
train_steps: int = None,
eval_steps: int = None,
eval_interval: int = None):
while self.global_step.numpy() < train_steps:
interval = min(train_steps - self.global_step.numpy(), eval_interval)
num_steps = self.global_step.numpy() + interval
self.train(steps=num_steps, checkpoint_at_completion=False)
self.evaluate(steps=eval_steps)
# Early stop condition.
if test_runner.eval_loss.result() < 0.1:
logging.info(
"Training early stopped as eval_loss %s is less than 0.1",
test_runner.eval_loss.result())
return
checkpoint = tf.train.Checkpoint(
model=test_runner.model, optimizer=test_runner.optimizer)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
self.model_dir,
max_to_keep=None,
step_counter=test_runner.global_step,
checkpoint_interval=10)
test_controller = EarlyStopController(
trainer=test_runner,
evaluator=test_runner,
global_step=test_runner.global_step,
steps_per_loop=2,
checkpoint_manager=checkpoint_manager)
test_controller.train_and_evaluate(
train_steps=10, eval_steps=6, eval_interval=2)
self.assertLess(test_runner.global_step, 10)
def test_evaluate_with_loss_outputs(self):
test_evaluator = TestEvaluator()
checkpoint = tf.train.Checkpoint(model=test_evaluator.model)
checkpoint.save(os.path.join(self.model_dir, "ckpt"))
checkpoint_manager = tf.train.CheckpointManager(
checkpoint, self.model_dir, max_to_keep=None)
test_controller = controller.Controller(
evaluator=test_evaluator,
global_step=tf.Variable(0, dtype=tf.int64),
checkpoint_manager=checkpoint_manager,
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
test_controller.evaluate(steps=5)
# Only eval summaries are written
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
self.assertTrue(
check_eventfile_for_keyword(
"eval_loss", os.path.join(self.model_dir, "summaries/eval")))
def test_train_and_evaluate_reset_datasets(self):
test_runner = TestRunner()
test_controller = controller.Controller(
trainer=test_runner,
evaluator=test_runner,
global_step=test_runner.global_step,
steps_per_loop=2)
test_controller.train_and_evaluate(
train_steps=10, eval_steps=2, eval_interval=6)
train_dataset = (
test_runner.strategy.experimental_distribute_datasets_from_function(
dataset_fn))
eval_dataset = (
test_runner.strategy.experimental_distribute_datasets_from_function(
dataset_fn))
test_runner.train_dataset = train_dataset
test_runner.eval_dataset = eval_dataset
test_controller.train_and_evaluate(
train_steps=10, eval_steps=2, eval_interval=6)
if __name__ == "__main__": if __name__ == "__main__":
tf.test.main() tf.test.main()
# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -14,19 +15,12 @@ ...@@ -14,19 +15,12 @@
# ============================================================================== # ==============================================================================
"""An abstraction that users can easily handle their custom training loops.""" """An abstraction that users can easily handle their custom training loops."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import abc import abc
import six
import tensorflow.compat.v2 as tf
from typing import Dict, Optional, Text from typing import Dict, Optional, Text
import tensorflow as tf
@six.add_metaclass(abc.ABCMeta) class AbstractTrainer(tf.Module, metaclass=abc.ABCMeta):
class AbstractTrainable(tf.Module):
"""An abstract class defining the APIs required for training.""" """An abstract class defining the APIs required for training."""
@abc.abstractmethod @abc.abstractmethod
...@@ -50,14 +44,13 @@ class AbstractTrainable(tf.Module): ...@@ -50,14 +44,13 @@ class AbstractTrainable(tf.Module):
one update to model parameters, e.g. if training a GAN). one update to model parameters, e.g. if training a GAN).
Returns: Returns:
The function may return a dictionary of `Tensors`, which will be The function may return a dictionary of `Tensors` or numpy arrays, which
written to logs and as TensorBoard summaries. will be written to logs and as TensorBoard summaries.
""" """
pass pass
@six.add_metaclass(abc.ABCMeta) class AbstractEvaluator(tf.Module, metaclass=abc.ABCMeta):
class AbstractEvaluable(tf.Module):
"""An abstract class defining the APIs required for evaluation.""" """An abstract class defining the APIs required for evaluation."""
@abc.abstractmethod @abc.abstractmethod
...@@ -73,7 +66,7 @@ class AbstractEvaluable(tf.Module): ...@@ -73,7 +66,7 @@ class AbstractEvaluable(tf.Module):
is `None`. is `None`.
Returns: Returns:
The function may return a dictionary of `Tensors`, which will be The function may return a dictionary of `Tensors` or numpy arrays, which
written to logs and as TensorBoard summaries. will be written to logs and as TensorBoard summaries.
""" """
pass pass
# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -14,67 +15,79 @@ ...@@ -14,67 +15,79 @@
# ============================================================================== # ==============================================================================
"""An abstraction that users can easily handle their custom training loops.""" """An abstraction that users can easily handle their custom training loops."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import abc import abc
import six from typing import Any, Dict, Optional, Text
import tensorflow.compat.v2 as tf from orbit import runner
from typing import Dict, Optional, Text from orbit import utils
import tensorflow as tf
from official.staging.training import runnable
from official.staging.training import utils
class StandardTrainer(runner.AbstractTrainer, metaclass=abc.ABCMeta):
"""Implements the standard functionality of AbstractTrainer APIs."""
@six.add_metaclass(abc.ABCMeta) def __init__(self,
class StandardTrainable(runnable.AbstractTrainable): train_dataset,
"""Implements the standard functionality of AbstractTrainable APIs.""" use_tf_while_loop=True,
use_tf_function=True,
use_tpu_summary_optimization=False):
"""Construct a `StandardTrainer` object.
def __init__(self, use_tf_while_loop=True, use_tf_function=True): Args:
train_dataset: A tf.nest-compatible structure of tf.data.Dataset or
DistributedDataset.
use_tf_while_loop: A boolean indicates whether to wrap the train step with
a `tf.while_loop`.
use_tf_function: A boolean indicates whether a `tf.function` will be used.
If False, training will run on pure eager mode.
use_tpu_summary_optimization: A boolean indicates whether to enable the
performance optimization for summaries in TPUs. In TPUs, writing
summaries with outside compilation inside train step is slow. If True,
it creates two `tf.function` with two XLA programs: one with summaries
and one without, and run the program with summaries (slow one) only if
necessary.
"""
if use_tf_while_loop and not use_tf_function: if use_tf_while_loop and not use_tf_function:
raise ValueError("`use_tf_while_loop=True` and `use_tf_function=False` " raise ValueError("`use_tf_while_loop=True` and `use_tf_function=False` "
"is not supported") "is not supported")
self.use_tf_while_loop = use_tf_while_loop if use_tpu_summary_optimization and not use_tf_while_loop:
self.use_tf_function = use_tf_function raise ValueError("`use_tpu_summary_optimization=True` and "
self.train_dataset = None "`use_tf_while_loop=False` is not supported")
self.train_iter = None self._use_tf_while_loop = use_tf_while_loop
self.train_loop_fn = None self._use_tf_function = use_tf_function
self._train_dataset = train_dataset
@abc.abstractmethod self._train_iter = None
def build_train_dataset(self): self._train_loop_fn = None
"""Builds the training datasets. self._use_tpu_summary_optimization = use_tpu_summary_optimization
Returns:
A tf.nest-compatible structure of tf.data.Dataset or DistributedDataset.
"""
pass
def train(self, def train(self,
num_steps: Optional[tf.Tensor]) -> Optional[Dict[Text, tf.Tensor]]: num_steps: Optional[tf.Tensor]) -> Optional[Dict[Text, tf.Tensor]]:
"""See base class.""" """See base class."""
if self.train_dataset is None: self.train_loop_begin()
# Build train input dataset
self.train_dataset = self.build_train_dataset() if self._train_iter is None:
self.train_iter = tf.nest.map_structure(iter, self.train_dataset) self._train_iter = tf.nest.map_structure(iter, self.train_dataset)
if self.train_loop_fn is None: if self._train_loop_fn is None:
train_fn = self.train_step train_fn = self.train_step
if self.use_tf_while_loop: if self._use_tf_while_loop:
self.train_loop_fn = utils.create_tf_while_loop_fn(train_fn) self._train_loop_fn = utils.create_tf_while_loop_fn(train_fn)
if self._use_tpu_summary_optimization:
self._train_loop_fn = utils.train_function_with_summaries(
self._train_loop_fn)
else:
self._train_loop_fn = tf.function(self._train_loop_fn)
else: else:
if self.use_tf_function: if self._use_tf_function:
train_fn = tf.function(train_fn) train_fn = tf.function(train_fn)
self.train_loop_fn = utils.create_loop_fn(train_fn) self._train_loop_fn = utils.create_loop_fn(train_fn)
self.train_loop_begin() self._train_loop_fn(self._train_iter, num_steps)
self.train_loop_fn(self.train_iter, num_steps)
return self.train_loop_end() return self.train_loop_end()
def train_loop_begin(self): def train_loop_begin(self):
"""Called once at the beginning of the training loop. """Called once at the beginning of the training loop.
This method is called before dataset iterators creation.
This is a good place to reset metrics that accumulate values over multiple This is a good place to reset metrics that accumulate values over multiple
steps of training. steps of training.
""" """
...@@ -107,54 +120,74 @@ class StandardTrainable(runnable.AbstractTrainable): ...@@ -107,54 +120,74 @@ class StandardTrainable(runnable.AbstractTrainable):
""" """
pass pass
@property
def train_dataset(self):
"""Returns the train_dataset instance."""
return self._train_dataset
@six.add_metaclass(abc.ABCMeta) @train_dataset.setter
class StandardEvaluable(runnable.AbstractEvaluable): def train_dataset(self, train_dataset):
"""Implements the standard functionality of AbstractEvaluable APIs.""" """Set a new train dataset and replace with the existing one.
def __init__(self, use_tf_function=True): Any unfinished work in the previous dataset will be discarded.
self.eval_use_tf_function = use_tf_function
self.eval_dataset = None
self.eval_loop_fn = None
@abc.abstractmethod Args:
def build_eval_dataset(self): train_dataset: A tf.nest-compatible structure of tf.data.Dataset or
"""Builds the evaluation datasets. DistributedDataset.
"""
self._train_dataset = train_dataset
self._train_iter = None
Returns:
A tf.nest-compatible structure of tf.data.Dataset or DistributedDataset. class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta):
"""Implements the standard functionality of AbstractEvaluator APIs."""
def __init__(self, eval_dataset, use_tf_function=True):
"""Construct a `StandardEvaluator` object.
Args:
eval_dataset: A tf.nest-compatible structure of tf.data.Dataset or
DistributedDataset.
use_tf_function: A boolean indicates whether a `tf.function` will be used.
If False, evaluation will run on pure eager mode.
""" """
pass self._eval_use_tf_function = use_tf_function
self._eval_dataset = eval_dataset
self._eval_loop_fn = None
def evaluate( def evaluate(
self, num_steps: Optional[tf.Tensor]) -> Optional[Dict[Text, tf.Tensor]]: self, num_steps: Optional[tf.Tensor]) -> Optional[Dict[Text, tf.Tensor]]:
"""See base class.""" """See base class."""
if self.eval_dataset is None: outputs = self.eval_begin() # pylint: disable=assignment-from-no-return
# Build train input dataset
self.eval_dataset = self.build_eval_dataset()
if self.eval_loop_fn is None: eval_iter = tf.nest.map_structure(iter, self._eval_dataset)
if self._eval_loop_fn is None:
eval_fn = self.eval_step eval_fn = self.eval_step
if self.eval_use_tf_function: if self._eval_use_tf_function:
eval_fn = tf.function(eval_fn) eval_fn = tf.function(eval_fn)
self.eval_loop_fn = utils.create_loop_fn(eval_fn) self._eval_loop_fn = utils.create_loop_fn(eval_fn)
eval_iter = tf.nest.map_structure(iter, self.eval_dataset) outputs = self._eval_loop_fn(
eval_iter, num_steps, state=outputs, reduce_fn=self.eval_reduce)
if outputs is None:
return self.eval_end()
else:
return self.eval_end(outputs)
self.eval_begin() def eval_begin(self) -> Any:
self.eval_loop_fn(eval_iter, num_steps)
return self.eval_end()
def eval_begin(self):
"""Called once at the beginning of the evaluation. """Called once at the beginning of the evaluation.
This method is called before dataset iterators creation.
This is a good place to reset metrics that accumulate values over the entire This is a good place to reset metrics that accumulate values over the entire
evaluation. evaluation.
Returns:
An output which is passed as `state` argument into `eval_reduce` function.
""" """
pass pass
@abc.abstractmethod @abc.abstractmethod
def eval_step(self, iterator): def eval_step(self, iterator) -> Any:
"""Implements one step of evaluation. """Implements one step of evaluation.
What a "step" consists of is up to the implementer. If using distribution What a "step" consists of is up to the implementer. If using distribution
...@@ -165,17 +198,57 @@ class StandardEvaluable(runnable.AbstractEvaluable): ...@@ -165,17 +198,57 @@ class StandardEvaluable(runnable.AbstractEvaluable):
Args: Args:
iterator: A tf.nest-compatible structure of tf.data Iterator or iterator: A tf.nest-compatible structure of tf.data Iterator or
DistributedIterator. DistributedIterator.
Returns:
An output which is passed as `step_outputs` argument into `eval_reduce`
function.
""" """
pass pass
def eval_end(self) -> Optional[Dict[Text, tf.Tensor]]: def eval_end(self, *args) -> Optional[Dict[Text, tf.Tensor]]:
"""Called at the end of the evaluation. """Called at the end of the evaluation.
This is a good place to get metric results. The value returned from this This is a good place to get metric results. The value returned from this
function will be returned as-is from the evaluate() method. function will be returned as-is from the evaluate() method.
Args:
*args: the outputs from `eval_reduce` for the last eval step.
Returns: Returns:
The function may return a dictionary of `Tensors`, which will be The function may return a dictionary of `Tensors`, which will be
written to logs and as TensorBoard summaries. written to logs and as TensorBoard summaries.
""" """
pass pass
def eval_reduce(self, state=None, step_outputs=None) -> Any:
"""A function to do the reduction on the evaluation outputs per step.
This is useful for passing states throughout evaluation. E.g. it can be used
to maintain the output losses from all the evaluation steps, and compute the
mean loss in `eval_end` function.
Args:
state: A maintained state throughout the evaluation.
step_outputs: Outputs from the current evaluation step.
Returns:
An output which is passed as `state` argument into `eval_reduce` function
for the next step. After evaluation is finished, the output from last step
will be passed into `eval_end` function.
"""
pass
@property
def eval_dataset(self):
  """Returns the current eval dataset instance."""
  return self._eval_dataset

@eval_dataset.setter
def eval_dataset(self, eval_dataset):
  """Sets a new eval dataset, replacing the existing one.

  Args:
    eval_dataset: A tf.nest-compatible structure of tf.data.Dataset or
      DistributedDataset.
  """
  self._eval_dataset = eval_dataset
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for orbit.standard_runner."""
# pylint: disable=g-bad-import-order
from orbit import standard_runner
import tensorflow as tf
def dataset_fn(input_context=None):
  """Returns an infinite dataset of zero-valued (1, 1) float32 tensors."""
  del input_context  # Unused; accepted for distribution-function compatibility.

  def _zeros(_):
    return tf.zeros((1, 1), dtype=tf.float32)

  return tf.data.Dataset.range(1).repeat().map(
      _zeros, num_parallel_calls=tf.data.experimental.AUTOTUNE)
class TestRunner(standard_runner.StandardTrainer,
                 standard_runner.StandardEvaluator):
  """A minimal trainer/evaluator implementation used by the tests below.

  Each train/eval step simply increments `global_step`, so the value of
  `global_step` after a run equals the number of steps executed.
  """

  def __init__(self):
    self.strategy = tf.distribute.get_strategy()
    self.global_step = tf.Variable(
        0,
        trainable=False,
        dtype=tf.int64,
        name='global_step',
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
    # Datasets are created lazily in the loop-begin hooks below.
    standard_runner.StandardTrainer.__init__(self, train_dataset=None)
    standard_runner.StandardEvaluator.__init__(self, eval_dataset=None)

  def train_loop_begin(self):
    # Build the distributed dataset just before training starts.
    self.train_dataset = (
        self.strategy.experimental_distribute_datasets_from_function(
            dataset_fn))

  def train_step(self, iterator):

    def _step_fn(_):
      self.global_step.assign_add(1)

    self.strategy.run(_step_fn, args=(next(iterator),))

  def train_loop_end(self):
    return self.global_step.numpy()

  def eval_begin(self):
    # Build the distributed dataset just before evaluation starts.
    self.eval_dataset = (
        self.strategy.experimental_distribute_datasets_from_function(
            dataset_fn))

  def eval_step(self, iterator):

    def _step_fn(_):
      self.global_step.assign_add(1)

    self.strategy.run(_step_fn, args=(next(iterator),))

  def eval_end(self):
    return self.global_step.numpy()
class StandardRunnerTest(tf.test.TestCase):
  """Checks that train/evaluate execute the requested number of steps."""

  def test_train(self):
    runner = TestRunner()
    num_steps = tf.convert_to_tensor(10, dtype=tf.int32)
    self.assertEqual(runner.train(num_steps), 10)

  def test_eval(self):
    runner = TestRunner()
    num_steps = tf.convert_to_tensor(10, dtype=tf.int32)
    self.assertEqual(runner.evaluate(num_steps), 10)
if __name__ == '__main__':
  # Run all test cases in this module via the TensorFlow test runner.
  tf.test.main()
# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -14,16 +15,13 @@ ...@@ -14,16 +15,13 @@
# ============================================================================== # ==============================================================================
"""Some layered modules/functions to help users writing custom training loop.""" """Some layered modules/functions to help users writing custom training loop."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import abc import abc
import contextlib
import functools
import inspect import inspect
import six
import tensorflow.compat.v2 as tf import numpy as np
import tensorflow as tf
def create_loop_fn(step_fn): def create_loop_fn(step_fn):
...@@ -79,7 +77,6 @@ def create_tf_while_loop_fn(step_fn): ...@@ -79,7 +77,6 @@ def create_tf_while_loop_fn(step_fn):
A callable defined as the `loop_fn` defination below. A callable defined as the `loop_fn` defination below.
""" """
@tf.function
def loop_fn(iterator, num_steps): def loop_fn(iterator, num_steps):
"""A loop function with multiple steps. """A loop function with multiple steps.
...@@ -130,10 +127,7 @@ def make_distributed_dataset(strategy, dataset_or_fn, *args, **kwargs): ...@@ -130,10 +127,7 @@ def make_distributed_dataset(strategy, dataset_or_fn, *args, **kwargs):
# names, pass `ctx` as the value of `input_context` when calling # names, pass `ctx` as the value of `input_context` when calling
# `dataset_or_fn`. Otherwise `ctx` will not be used when calling # `dataset_or_fn`. Otherwise `ctx` will not be used when calling
# `dataset_or_fn`. # `dataset_or_fn`.
if six.PY3: argspec = inspect.getfullargspec(dataset_or_fn)
argspec = inspect.getfullargspec(dataset_or_fn)
else:
argspec = inspect.getargspec(dataset_or_fn)
args_names = argspec.args args_names = argspec.args
if "input_context" in args_names: if "input_context" in args_names:
...@@ -144,96 +138,62 @@ def make_distributed_dataset(strategy, dataset_or_fn, *args, **kwargs): ...@@ -144,96 +138,62 @@ def make_distributed_dataset(strategy, dataset_or_fn, *args, **kwargs):
return strategy.experimental_distribute_datasets_from_function(dataset_fn) return strategy.experimental_distribute_datasets_from_function(dataset_fn)
class SummaryManager(object): class SummaryManager:
"""A class manages writing summaries.""" """A class manages writing summaries."""
def __init__(self, def __init__(self, summary_dir, summary_fn, global_step=None):
summary_writer,
summary_fn,
global_step=None,
summary_interval=None):
"""Construct a summary manager object. """Construct a summary manager object.
Args: Args:
summary_writer: A `tf.summary.SummaryWriter` instance for writing summary_dir: the directory to write summaries.
summaries.
summary_fn: A callable defined as `def summary_fn(name, tensor, summary_fn: A callable defined as `def summary_fn(name, tensor,
step=None)`, which describes the summary operation. step=None)`, which describes the summary operation.
global_step: A `tf.Variable` instance for checking the current global step global_step: A `tf.Variable` instance for the global step.
value, in case users want to save summaries every N steps.
summary_interval: An integer, indicates the minimum step interval between
two summaries.
""" """
if summary_writer is not None: self._enabled = (summary_dir is not None)
self._summary_writer = summary_writer self._summary_dir = summary_dir
self._enabled = True
else:
self._summary_writer = tf.summary.create_noop_writer()
self._enabled = False
self._summary_fn = summary_fn self._summary_fn = summary_fn
self._summary_writer = None
if global_step is None: if global_step is None:
self._global_step = tf.summary.experimental.get_step() self._global_step = tf.summary.experimental.get_step()
else: else:
self._global_step = global_step self._global_step = global_step
if summary_interval is not None:
if self._global_step is None:
raise ValueError("`summary_interval` is not None, but no `global_step` "
"can be obtained ")
self._last_summary_step = self._global_step.numpy()
self._summary_interval = summary_interval
@property
def summary_interval(self):
return self._summary_interval
@property @property
def summary_writer(self): def summary_writer(self):
"""Returns the underlying summary writer.""" """Returns the underlying summary writer."""
if self._summary_writer is not None:
return self._summary_writer
if self._enabled:
self._summary_writer = tf.summary.create_file_writer(self._summary_dir)
else:
self._summary_writer = tf.summary.create_noop_writer()
return self._summary_writer return self._summary_writer
def flush(self): def flush(self):
"""Flush the underlying summary writer.""" """Flush the underlying summary writer."""
if self._enabled: if self._enabled:
tf.summary.flush(self._summary_writer) tf.summary.flush(self.summary_writer)
def write_summaries(self, items, always_write=True): def write_summaries(self, items):
"""Write a bulk of summaries. """Write a bulk of summaries.
Args: Args:
items: a dictionary of `Tensors` for writing summaries. items: a dictionary of `Tensors` for writing summaries.
always_write: An optional boolean. If `True`, the manager will always
write summaries unless the summaries have been written for the same
step. Otherwise the manager will only write the summaries if the
interval between summaries are larger than `summary_interval`.
Returns:
A boolean indicates whether the summaries are written or not.
""" """
# TODO(rxsang): Support writing summaries with nested structure, so users # TODO(rxsang): Support writing summaries with nested structure, so users
# can split the summaries into different directories for nicer visualization # can split the summaries into different directories for nicer visualization
# in Tensorboard, like train and eval metrics. # in Tensorboard, like train and eval metrics.
if not self._enabled: if not self._enabled:
return False return
if self._summary_interval is not None:
current_step = self._global_step.numpy()
if current_step == self._last_summary_step:
return False
if not always_write and current_step < (self._last_summary_step +
self._summary_interval):
return False
self._last_summary_step = current_step
with self._summary_writer.as_default(): with self.summary_writer.as_default():
for name, tensor in items.items(): for name, tensor in items.items():
self._summary_fn(name, tensor, step=self._global_step) self._summary_fn(name, tensor, step=self._global_step)
return True
@six.add_metaclass(abc.ABCMeta) class Trigger(metaclass=abc.ABCMeta):
class Trigger(object):
"""An abstract class representing a "trigger" for some event.""" """An abstract class representing a "trigger" for some event."""
@abc.abstractmethod @abc.abstractmethod
...@@ -294,7 +254,7 @@ class IntervalTrigger(Trigger): ...@@ -294,7 +254,7 @@ class IntervalTrigger(Trigger):
self._last_trigger_value = 0 self._last_trigger_value = 0
class EpochHelper(object): class EpochHelper:
"""A Helper class to handle epochs in Customized Training Loop.""" """A Helper class to handle epochs in Customized Training Loop."""
def __init__(self, epoch_steps, global_step): def __init__(self, epoch_steps, global_step):
...@@ -340,3 +300,86 @@ class EpochHelper(object): ...@@ -340,3 +300,86 @@ class EpochHelper(object):
@property @property
def current_epoch(self): def current_epoch(self):
return self._current_epoch return self._current_epoch
@contextlib.contextmanager
def _soft_device_placement():
  """Temporarily enables soft device placement (e.g. for summaries on CPU).

  Restores the previous soft-device-placement setting on exit, even if the
  wrapped body raises.
  """
  previous = tf.config.get_soft_device_placement()
  try:
    tf.config.set_soft_device_placement(True)
    yield
  finally:
    tf.config.set_soft_device_placement(previous)
def train_function_with_summaries(*args, **kwargs):
  """Utility function to support TPU summaries via multiple `tf.function`s.

  This permits interleaving summaries inside TPU-compatible code, but without
  any performance impact on steps that do not write summaries.

  Usage is as a decorator, similar to `tf.function`, and any `tf.function`
  arguments will be passed through if supplied:

    @trainer.train_function_with_summaries
    def train(self, num_steps):
      ...

  The decorated function is assumed to be a loop method accepting a `num_steps`
  parameter, as for instance would be called within the `Controller`'s outer
  train loop. The implementation here assumes that `summary_frequency` is
  divisible by `steps_per_loop`. The decorated method should accept two
  arguments, `self` and `num_steps`.

  Two `tf.function` versions of `train_fn` are created: one inside a summary
  writer scope with soft device placement enabled (used on steps that require
  summary writing), and one with no summary writer present and soft device
  placement disabled (used on all other steps).

  Args:
    *args: Arguments to pass through to `tf.function`.
    **kwargs: Keyword arguments to pass through to `tf.function`.

  Returns:
    If the first argument is a callable, returns the decorated callable.
    Otherwise, returns a decorator.
  """

  def decorator(train_fn):
    # TODO(dhr): Validate the signature of train_fn?
    # Two separate traced wrappers of the same Python function; which one
    # is invoked (and hence whether summary ops run) is chosen per call below.
    train_fn_with_summaries = tf.function(train_fn, *args, **kwargs)
    train_fn_without_summaries = tf.function(train_fn, *args, **kwargs)

    @functools.wraps(train_fn)
    def wrapper(self, num_steps):
      # On summary steps: run exactly one step with summaries enabled (soft
      # device placement lets summary ops fall back to CPU), then run the
      # remaining steps with summary recording disabled.
      if tf.summary.should_record_summaries():
        with _soft_device_placement():
          output = train_fn_with_summaries(self, tf.constant(1))
          num_steps -= 1
      if num_steps >= 1:
        with tf.summary.record_if(False):
          output = train_fn_without_summaries(self, num_steps)
      return output

    return wrapper

  # Support bare-decorator usage (`@train_function_with_summaries` with no
  # parentheses): the decorated function arrives as the first positional arg.
  if args and callable(args[0]):
    train_fn, args = args[0], args[1:]
    return decorator(train_fn)
  return decorator
def get_value(x) -> np.ndarray:
  """Returns the concrete value of a variable/tensor.

  Args:
    x: input variable.

  Returns:
    A Numpy array (`x` is returned unchanged when it is not a tensor).
  """
  return x.numpy() if tf.is_tensor(x) else x
## Attention-based Extraction of Structured Information from Street View Imagery # Attention-based Extraction of Structured Information from Street View Imagery
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/attention-based-extraction-of-structured/optical-character-recognition-on-fsns-test)](https://paperswithcode.com/sota/optical-character-recognition-on-fsns-test?p=attention-based-extraction-of-structured) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/attention-based-extraction-of-structured/optical-character-recognition-on-fsns-test)](https://paperswithcode.com/sota/optical-character-recognition-on-fsns-test?p=attention-based-extraction-of-structured)
[![Paper](http://img.shields.io/badge/paper-arXiv.1704.03549-B3181B.svg)](https://arxiv.org/abs/1704.03549) [![Paper](http://img.shields.io/badge/paper-arXiv.1704.03549-B3181B.svg)](https://arxiv.org/abs/1704.03549)
...@@ -7,14 +7,20 @@ ...@@ -7,14 +7,20 @@
*A TensorFlow model for real-world image text extraction problems.* *A TensorFlow model for real-world image text extraction problems.*
This folder contains the code needed to train a new Attention OCR model on the This folder contains the code needed to train a new Attention OCR model on the
[FSNS dataset][FSNS] dataset to transcribe street names in France. You can [FSNS dataset][FSNS] to transcribe street names in France. You can also train the code on your own data.
also use it to train it on your own data.
More details can be found in our paper: More details can be found in our paper:
["Attention-based Extraction of Structured Information from Street View ["Attention-based Extraction of Structured Information from Street View
Imagery"](https://arxiv.org/abs/1704.03549) Imagery"](https://arxiv.org/abs/1704.03549)
## Description
* The paper presents a model based on ConvNets, RNNs and a novel attention mechanism.
It achieves **84.2%** on FSNS, beating the previous benchmark (**72.46%**), and also studies
the speed/accuracy tradeoff that results from using CNN feature extractors of
different depths.
## Contacts ## Contacts
Authors Authors
...@@ -22,7 +28,18 @@ Authors ...@@ -22,7 +28,18 @@ Authors
* Zbigniew Wojna (zbigniewwojna@gmail.com) * Zbigniew Wojna (zbigniewwojna@gmail.com)
* Alexander Gorban (gorban@google.com) * Alexander Gorban (gorban@google.com)
Maintainer: Xavier Gibert [@xavigibert](https://github.com/xavigibert) Maintainer
* Xavier Gibert ([@xavigibert](https://github.com/xavigibert))
## Table of Contents
* [Requirements](https://github.com/tensorflow/models/blob/master/research/attention_ocr/README.md#requirements)
* [Dataset](https://github.com/tensorflow/models/blob/master/research/attention_ocr/README.md#dataset)
* [How to use this code](https://github.com/tensorflow/models/blob/master/research/attention_ocr/README.md#how-to-use-this-code)
* [Using your own image data](https://github.com/tensorflow/models/blob/master/research/attention_ocr/README.md#using-your-own-image-data)
* [How to use a pre-trained model](https://github.com/tensorflow/models/blob/master/research/attention_ocr/README.md#how-to-use-a-pre-trained-model)
* [Disclaimer](https://github.com/tensorflow/models/blob/master/research/attention_ocr/README.md#disclaimer)
## Requirements ## Requirements
...@@ -49,6 +66,42 @@ cd .. ...@@ -49,6 +66,42 @@ cd ..
[TF]: https://www.tensorflow.org/install/ [TF]: https://www.tensorflow.org/install/
[FSNS]: https://github.com/tensorflow/models/tree/master/research/street [FSNS]: https://github.com/tensorflow/models/tree/master/research/street
## Dataset
The French Street Name Signs (FSNS) dataset is split into subsets,
each of which is composed of multiple files. Note that these datasets
are very large. The approximate sizes are:
* Train: 512 files of 300MB each.
* Validation: 64 files of 40MB each.
* Test: 64 files of 50MB each.
* The download also includes a `testdata` directory containing small sample
datasets that are just big enough to verify that models can
actually learn something.
* Total: around 158GB
The download paths are in the following list:
```
https://download.tensorflow.org/data/fsns-20160927/charset_size=134.txt
https://download.tensorflow.org/data/fsns-20160927/test/test-00000-of-00064
...
https://download.tensorflow.org/data/fsns-20160927/test/test-00063-of-00064
https://download.tensorflow.org/data/fsns-20160927/testdata/arial-32-00000-of-00001
https://download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001
https://download.tensorflow.org/data/fsns-20160927/testdata/mnist-sample-00000-of-00001
https://download.tensorflow.org/data/fsns-20160927/testdata/numbers-16-00000-of-00001
https://download.tensorflow.org/data/fsns-20160927/train/train-00000-of-00512
...
https://download.tensorflow.org/data/fsns-20160927/train/train-00511-of-00512
https://download.tensorflow.org/data/fsns-20160927/validation/validation-00000-of-00064
...
https://download.tensorflow.org/data/fsns-20160927/validation/validation-00063-of-00064
```
All URLs are stored in the [research/street](https://github.com/tensorflow/models/tree/master/research/street)
repository in the text file `python/fsns_urls.txt`.
## How to use this code ## How to use this code
To run all unit tests: To run all unit tests:
...@@ -80,7 +133,7 @@ tar xf attention_ocr_2017_08_09.tar.gz ...@@ -80,7 +133,7 @@ tar xf attention_ocr_2017_08_09.tar.gz
python train.py --checkpoint=model.ckpt-399731 python train.py --checkpoint=model.ckpt-399731
``` ```
## How to use your own image data to train the model ## Using your own image data
You need to define a new dataset. There are two options: You need to define a new dataset. There are two options:
...@@ -166,6 +219,14 @@ implement one in Python or C++. ...@@ -166,6 +219,14 @@ implement one in Python or C++.
The recommended way is to use the [Serving infrastructure][serving]. The recommended way is to use the [Serving infrastructure][serving].
To export to SavedModel format:
```
python model_export.py \
--checkpoint=model.ckpt-399731 \
--export_dir=/tmp/attention_ocr_export
```
Alternatively you can: Alternatively you can:
1. define a placeholder for images (or use directly an numpy array) 1. define a placeholder for images (or use directly an numpy array)
2. [create a graph ](https://github.com/tensorflow/models/blob/master/research/attention_ocr/python/eval.py#L60) 2. [create a graph ](https://github.com/tensorflow/models/blob/master/research/attention_ocr/python/eval.py#L60)
...@@ -188,7 +249,7 @@ other than a one time experiment please use the [TensorFlow Serving][serving]. ...@@ -188,7 +249,7 @@ other than a one time experiment please use the [TensorFlow Serving][serving].
[1]: https://github.com/tensorflow/tensorflow/blob/aaf7adc/tensorflow/contrib/rnn/python/tools/checkpoint_convert.py [1]: https://github.com/tensorflow/tensorflow/blob/aaf7adc/tensorflow/contrib/rnn/python/tools/checkpoint_convert.py
[2]: https://www.tensorflow.org/api_docs/python/tf/contrib/framework/assign_from_checkpoint_fn [2]: https://www.tensorflow.org/api_docs/python/tf/contrib/framework/assign_from_checkpoint_fn
[serving]: https://tensorflow.github.io/serving/serving_basic [serving]: https://www.tensorflow.org/tfx/serving/serving_basic
## Disclaimer ## Disclaimer
......
...@@ -14,10 +14,10 @@ ...@@ -14,10 +14,10 @@
# ============================================================================== # ==============================================================================
"""Define flags are common for both train.py and eval.py scripts.""" """Define flags are common for both train.py and eval.py scripts."""
import logging
import sys import sys
from tensorflow.python.platform import flags from tensorflow.python.platform import flags
import logging
import datasets import datasets
import model import model
...@@ -35,9 +35,17 @@ logging.basicConfig( ...@@ -35,9 +35,17 @@ logging.basicConfig(
datefmt='%Y-%m-%d %H:%M:%S') datefmt='%Y-%m-%d %H:%M:%S')
_common_flags_defined = False
def define(): def define():
"""Define common flags.""" """Define common flags."""
# yapf: disable # yapf: disable
# common_flags.define() may be called multiple times in unit tests.
global _common_flags_defined
if _common_flags_defined:
return
_common_flags_defined = True
flags.DEFINE_integer('batch_size', 32, flags.DEFINE_integer('batch_size', 32,
'Batch size.') 'Batch size.')
...@@ -74,7 +82,7 @@ def define(): ...@@ -74,7 +82,7 @@ def define():
'the optimizer to use') 'the optimizer to use')
flags.DEFINE_float('momentum', 0.9, flags.DEFINE_float('momentum', 0.9,
'momentum value for the momentum optimizer if used') 'momentum value for the momentum optimizer if used')
flags.DEFINE_bool('use_augment_input', True, flags.DEFINE_bool('use_augment_input', True,
'If True will use image augmentation') 'If True will use image augmentation')
......
...@@ -56,14 +56,14 @@ def augment_image(image): ...@@ -56,14 +56,14 @@ def augment_image(image):
Returns: Returns:
Distorted Tensor image of the same shape. Distorted Tensor image of the same shape.
""" """
with tf.variable_scope('AugmentImage'): with tf.compat.v1.variable_scope('AugmentImage'):
height = image.get_shape().dims[0].value height = image.get_shape().dims[0].value
width = image.get_shape().dims[1].value width = image.get_shape().dims[1].value
# Random crop cut from the street sign image, resized to the same size. # Random crop cut from the street sign image, resized to the same size.
# Assures that the crop is covers at least 0.8 area of the input image. # Assures that the crop is covers at least 0.8 area of the input image.
bbox_begin, bbox_size, _ = tf.image.sample_distorted_bounding_box( bbox_begin, bbox_size, _ = tf.image.sample_distorted_bounding_box(
tf.shape(image), image_size=tf.shape(input=image),
bounding_boxes=tf.zeros([0, 0, 4]), bounding_boxes=tf.zeros([0, 0, 4]),
min_object_covered=0.8, min_object_covered=0.8,
aspect_ratio_range=[0.8, 1.2], aspect_ratio_range=[0.8, 1.2],
...@@ -74,7 +74,7 @@ def augment_image(image): ...@@ -74,7 +74,7 @@ def augment_image(image):
# Randomly chooses one of the 4 interpolation methods # Randomly chooses one of the 4 interpolation methods
distorted_image = inception_preprocessing.apply_with_random_selector( distorted_image = inception_preprocessing.apply_with_random_selector(
distorted_image, distorted_image,
lambda x, method: tf.image.resize_images(x, [height, width], method), lambda x, method: tf.image.resize(x, [height, width], method),
num_cases=4) num_cases=4)
distorted_image.set_shape([height, width, 3]) distorted_image.set_shape([height, width, 3])
...@@ -99,9 +99,10 @@ def central_crop(image, crop_size): ...@@ -99,9 +99,10 @@ def central_crop(image, crop_size):
Returns: Returns:
A tensor of shape [crop_height, crop_width, channels]. A tensor of shape [crop_height, crop_width, channels].
""" """
with tf.variable_scope('CentralCrop'): with tf.compat.v1.variable_scope('CentralCrop'):
target_width, target_height = crop_size target_width, target_height = crop_size
image_height, image_width = tf.shape(image)[0], tf.shape(image)[1] image_height, image_width = tf.shape(
input=image)[0], tf.shape(input=image)[1]
assert_op1 = tf.Assert( assert_op1 = tf.Assert(
tf.greater_equal(image_height, target_height), tf.greater_equal(image_height, target_height),
['image_height < target_height', image_height, target_height]) ['image_height < target_height', image_height, target_height])
...@@ -129,7 +130,7 @@ def preprocess_image(image, augment=False, central_crop_size=None, ...@@ -129,7 +130,7 @@ def preprocess_image(image, augment=False, central_crop_size=None,
A float32 tensor of shape [H x W x 3] with RGB values in the required A float32 tensor of shape [H x W x 3] with RGB values in the required
range. range.
""" """
with tf.variable_scope('PreprocessImage'): with tf.compat.v1.variable_scope('PreprocessImage'):
image = tf.image.convert_image_dtype(image, dtype=tf.float32) image = tf.image.convert_image_dtype(image, dtype=tf.float32)
if augment or central_crop_size: if augment or central_crop_size:
if num_towers == 1: if num_towers == 1:
...@@ -144,9 +145,6 @@ def preprocess_image(image, augment=False, central_crop_size=None, ...@@ -144,9 +145,6 @@ def preprocess_image(image, augment=False, central_crop_size=None,
images = [augment_image(img) for img in images] images = [augment_image(img) for img in images]
image = tf.concat(images, 1) image = tf.concat(images, 1)
image = tf.subtract(image, 0.5)
image = tf.multiply(image, 2.5)
return image return image
...@@ -185,7 +183,7 @@ def get_data(dataset, ...@@ -185,7 +183,7 @@ def get_data(dataset,
image_orig, augment, central_crop_size, num_towers=dataset.num_of_views) image_orig, augment, central_crop_size, num_towers=dataset.num_of_views)
label_one_hot = slim.one_hot_encoding(label, dataset.num_char_classes) label_one_hot = slim.one_hot_encoding(label, dataset.num_char_classes)
images, images_orig, labels, labels_one_hot = (tf.train.shuffle_batch( images, images_orig, labels, labels_one_hot = (tf.compat.v1.train.shuffle_batch(
[image, image_orig, label, label_one_hot], [image, image_orig, label, label_one_hot],
batch_size=batch_size, batch_size=batch_size,
num_threads=shuffle_config.num_batching_threads, num_threads=shuffle_config.num_batching_threads,
......
...@@ -72,7 +72,7 @@ def read_charset(filename, null_character=u'\u2591'): ...@@ -72,7 +72,7 @@ def read_charset(filename, null_character=u'\u2591'):
""" """
pattern = re.compile(r'(\d+)\t(.+)') pattern = re.compile(r'(\d+)\t(.+)')
charset = {} charset = {}
with tf.gfile.GFile(filename) as f: with tf.io.gfile.GFile(filename) as f:
for i, line in enumerate(f): for i, line in enumerate(f):
m = pattern.match(line) m = pattern.match(line)
if m is None: if m is None:
...@@ -96,9 +96,9 @@ class _NumOfViewsHandler(slim.tfexample_decoder.ItemHandler): ...@@ -96,9 +96,9 @@ class _NumOfViewsHandler(slim.tfexample_decoder.ItemHandler):
self._num_of_views = num_of_views self._num_of_views = num_of_views
def tensors_to_item(self, keys_to_tensors): def tensors_to_item(self, keys_to_tensors):
return tf.to_int64( return tf.cast(
self._num_of_views * keys_to_tensors[self._original_width_key] / self._num_of_views * keys_to_tensors[self._original_width_key] /
keys_to_tensors[self._width_key]) keys_to_tensors[self._width_key], dtype=tf.int64)
def get_split(split_name, dataset_dir=None, config=None): def get_split(split_name, dataset_dir=None, config=None):
...@@ -133,19 +133,19 @@ def get_split(split_name, dataset_dir=None, config=None): ...@@ -133,19 +133,19 @@ def get_split(split_name, dataset_dir=None, config=None):
zero = tf.zeros([1], dtype=tf.int64) zero = tf.zeros([1], dtype=tf.int64)
keys_to_features = { keys_to_features = {
'image/encoded': 'image/encoded':
tf.FixedLenFeature((), tf.string, default_value=''), tf.io.FixedLenFeature((), tf.string, default_value=''),
'image/format': 'image/format':
tf.FixedLenFeature((), tf.string, default_value='png'), tf.io.FixedLenFeature((), tf.string, default_value='png'),
'image/width': 'image/width':
tf.FixedLenFeature([1], tf.int64, default_value=zero), tf.io.FixedLenFeature([1], tf.int64, default_value=zero),
'image/orig_width': 'image/orig_width':
tf.FixedLenFeature([1], tf.int64, default_value=zero), tf.io.FixedLenFeature([1], tf.int64, default_value=zero),
'image/class': 'image/class':
tf.FixedLenFeature([config['max_sequence_length']], tf.int64), tf.io.FixedLenFeature([config['max_sequence_length']], tf.int64),
'image/unpadded_class': 'image/unpadded_class':
tf.VarLenFeature(tf.int64), tf.io.VarLenFeature(tf.int64),
'image/text': 'image/text':
tf.FixedLenFeature([1], tf.string, default_value=''), tf.io.FixedLenFeature([1], tf.string, default_value=''),
} }
items_to_handlers = { items_to_handlers = {
'image': 'image':
...@@ -171,12 +171,14 @@ def get_split(split_name, dataset_dir=None, config=None): ...@@ -171,12 +171,14 @@ def get_split(split_name, dataset_dir=None, config=None):
config['splits'][split_name]['pattern']) config['splits'][split_name]['pattern'])
return slim.dataset.Dataset( return slim.dataset.Dataset(
data_sources=file_pattern, data_sources=file_pattern,
reader=tf.TFRecordReader, reader=tf.compat.v1.TFRecordReader,
decoder=decoder, decoder=decoder,
num_samples=config['splits'][split_name]['size'], num_samples=config['splits'][split_name]['size'],
items_to_descriptions=config['items_to_descriptions'], items_to_descriptions=config['items_to_descriptions'],
# additional parameters for convenience. # additional parameters for convenience.
charset=charset, charset=charset,
charset_file=charset_file,
image_shape=config['image_shape'],
num_char_classes=len(charset), num_char_classes=len(charset),
num_of_views=config['num_of_views'], num_of_views=config['num_of_views'],
max_sequence_length=config['max_sequence_length'], max_sequence_length=config['max_sequence_length'],
......
...@@ -91,7 +91,7 @@ class FsnsTest(tf.test.TestCase): ...@@ -91,7 +91,7 @@ class FsnsTest(tf.test.TestCase):
image_tf, label_tf = provider.get(['image', 'label']) image_tf, label_tf = provider.get(['image', 'label'])
with self.test_session() as sess: with self.test_session() as sess:
sess.run(tf.global_variables_initializer()) sess.run(tf.compat.v1.global_variables_initializer())
with slim.queues.QueueRunners(sess): with slim.queues.QueueRunners(sess):
image_np, label_np = sess.run([image_tf, label_tf]) image_np, label_np = sess.run([image_tf, label_tf])
......
...@@ -10,7 +10,8 @@ KEEP_NUM_RECORDS = 5 ...@@ -10,7 +10,8 @@ KEEP_NUM_RECORDS = 5
print('Downloading %s ...' % URL) print('Downloading %s ...' % URL)
urllib.request.urlretrieve(URL, DST_ORIG) urllib.request.urlretrieve(URL, DST_ORIG)
print('Writing %d records from %s to %s ...' % (KEEP_NUM_RECORDS, DST_ORIG, DST)) print('Writing %d records from %s to %s ...' %
(KEEP_NUM_RECORDS, DST_ORIG, DST))
with tf.io.TFRecordWriter(DST) as writer: with tf.io.TFRecordWriter(DST) as writer:
for raw_record in itertools.islice(tf.python_io.tf_record_iterator(DST_ORIG), KEEP_NUM_RECORDS): for raw_record in itertools.islice(tf.compat.v1.python_io.tf_record_iterator(DST_ORIG), KEEP_NUM_RECORDS):
writer.write(raw_record) writer.write(raw_record)
...@@ -49,7 +49,7 @@ def load_images(file_pattern, batch_size, dataset_name): ...@@ -49,7 +49,7 @@ def load_images(file_pattern, batch_size, dataset_name):
for i in range(batch_size): for i in range(batch_size):
path = file_pattern % i path = file_pattern % i
print("Reading %s" % path) print("Reading %s" % path)
pil_image = PIL.Image.open(tf.gfile.GFile(path, 'rb')) pil_image = PIL.Image.open(tf.io.gfile.GFile(path, 'rb'))
images_actual_data[i, ...] = np.asarray(pil_image) images_actual_data[i, ...] = np.asarray(pil_image)
return images_actual_data return images_actual_data
...@@ -58,12 +58,13 @@ def create_model(batch_size, dataset_name): ...@@ -58,12 +58,13 @@ def create_model(batch_size, dataset_name):
width, height = get_dataset_image_size(dataset_name) width, height = get_dataset_image_size(dataset_name)
dataset = common_flags.create_dataset(split_name=FLAGS.split_name) dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
model = common_flags.create_model( model = common_flags.create_model(
num_char_classes=dataset.num_char_classes, num_char_classes=dataset.num_char_classes,
seq_length=dataset.max_sequence_length, seq_length=dataset.max_sequence_length,
num_views=dataset.num_of_views, num_views=dataset.num_of_views,
null_code=dataset.null_code, null_code=dataset.null_code,
charset=dataset.charset) charset=dataset.charset)
raw_images = tf.placeholder(tf.uint8, shape=[batch_size, height, width, 3]) raw_images = tf.compat.v1.placeholder(
tf.uint8, shape=[batch_size, height, width, 3])
images = tf.map_fn(data_provider.preprocess_image, raw_images, images = tf.map_fn(data_provider.preprocess_image, raw_images,
dtype=tf.float32) dtype=tf.float32)
endpoints = model.create_base(images, labels_one_hot=None) endpoints = model.create_base(images, labels_one_hot=None)
...@@ -76,9 +77,9 @@ def run(checkpoint, batch_size, dataset_name, image_path_pattern): ...@@ -76,9 +77,9 @@ def run(checkpoint, batch_size, dataset_name, image_path_pattern):
images_data = load_images(image_path_pattern, batch_size, images_data = load_images(image_path_pattern, batch_size,
dataset_name) dataset_name)
session_creator = monitored_session.ChiefSessionCreator( session_creator = monitored_session.ChiefSessionCreator(
checkpoint_filename_with_path=checkpoint) checkpoint_filename_with_path=checkpoint)
with monitored_session.MonitoredSession( with monitored_session.MonitoredSession(
session_creator=session_creator) as sess: session_creator=session_creator) as sess:
predictions = sess.run(endpoints.predicted_text, predictions = sess.run(endpoints.predicted_text,
feed_dict={images_placeholder: images_data}) feed_dict={images_placeholder: images_data})
return [pr_bytes.decode('utf-8') for pr_bytes in predictions.tolist()] return [pr_bytes.decode('utf-8') for pr_bytes in predictions.tolist()]
...@@ -87,10 +88,10 @@ def run(checkpoint, batch_size, dataset_name, image_path_pattern): ...@@ -87,10 +88,10 @@ def run(checkpoint, batch_size, dataset_name, image_path_pattern):
def main(_): def main(_):
print("Predicted strings:") print("Predicted strings:")
predictions = run(FLAGS.checkpoint, FLAGS.batch_size, FLAGS.dataset_name, predictions = run(FLAGS.checkpoint, FLAGS.batch_size, FLAGS.dataset_name,
FLAGS.image_path_pattern) FLAGS.image_path_pattern)
for line in predictions: for line in predictions:
print(line) print(line)
if __name__ == '__main__': if __name__ == '__main__':
tf.app.run() tf.compat.v1.app.run()
...@@ -14,12 +14,13 @@ class DemoInferenceTest(tf.test.TestCase): ...@@ -14,12 +14,13 @@ class DemoInferenceTest(tf.test.TestCase):
super(DemoInferenceTest, self).setUp() super(DemoInferenceTest, self).setUp()
for suffix in ['.meta', '.index', '.data-00000-of-00001']: for suffix in ['.meta', '.index', '.data-00000-of-00001']:
filename = _CHECKPOINT + suffix filename = _CHECKPOINT + suffix
self.assertTrue(tf.gfile.Exists(filename), self.assertTrue(tf.io.gfile.exists(filename),
msg='Missing checkpoint file %s. ' msg='Missing checkpoint file %s. '
'Please download and extract it from %s' % 'Please download and extract it from %s' %
(filename, _CHECKPOINT_URL)) (filename, _CHECKPOINT_URL))
self._batch_size = 32 self._batch_size = 32
tf.flags.FLAGS.dataset_dir = os.path.join(os.path.dirname(__file__), 'datasets/testdata/fsns') tf.flags.FLAGS.dataset_dir = os.path.join(
os.path.dirname(__file__), 'datasets/testdata/fsns')
def test_moving_variables_properly_loaded_from_a_checkpoint(self): def test_moving_variables_properly_loaded_from_a_checkpoint(self):
batch_size = 32 batch_size = 32
...@@ -30,15 +31,15 @@ class DemoInferenceTest(tf.test.TestCase): ...@@ -30,15 +31,15 @@ class DemoInferenceTest(tf.test.TestCase):
images_data = demo_inference.load_images(image_path_pattern, batch_size, images_data = demo_inference.load_images(image_path_pattern, batch_size,
dataset_name) dataset_name)
tensor_name = 'AttentionOcr_v1/conv_tower_fn/INCE/InceptionV3/Conv2d_2a_3x3/BatchNorm/moving_mean' tensor_name = 'AttentionOcr_v1/conv_tower_fn/INCE/InceptionV3/Conv2d_2a_3x3/BatchNorm/moving_mean'
moving_mean_tf = tf.get_default_graph().get_tensor_by_name( moving_mean_tf = tf.compat.v1.get_default_graph().get_tensor_by_name(
tensor_name + ':0') tensor_name + ':0')
reader = tf.train.NewCheckpointReader(_CHECKPOINT) reader = tf.compat.v1.train.NewCheckpointReader(_CHECKPOINT)
moving_mean_expected = reader.get_tensor(tensor_name) moving_mean_expected = reader.get_tensor(tensor_name)
session_creator = monitored_session.ChiefSessionCreator( session_creator = monitored_session.ChiefSessionCreator(
checkpoint_filename_with_path=_CHECKPOINT) checkpoint_filename_with_path=_CHECKPOINT)
with monitored_session.MonitoredSession( with monitored_session.MonitoredSession(
session_creator=session_creator) as sess: session_creator=session_creator) as sess:
moving_mean_np = sess.run(moving_mean_tf, moving_mean_np = sess.run(moving_mean_tf,
feed_dict={images_placeholder: images_data}) feed_dict={images_placeholder: images_data})
...@@ -50,38 +51,38 @@ class DemoInferenceTest(tf.test.TestCase): ...@@ -50,38 +51,38 @@ class DemoInferenceTest(tf.test.TestCase):
'fsns', 'fsns',
image_path_pattern) image_path_pattern)
self.assertEqual([ self.assertEqual([
u'Boulevard de Lunel░░░░░░░░░░░░░░░░░░░', u'Boulevard de Lunel░░░░░░░░░░░░░░░░░░░',
'Rue de Provence░░░░░░░░░░░░░░░░░░░░░░', 'Rue de Provence░░░░░░░░░░░░░░░░░░░░░░',
'Rue de Port Maria░░░░░░░░░░░░░░░░░░░░', 'Rue de Port Maria░░░░░░░░░░░░░░░░░░░░',
'Avenue Charles Gounod░░░░░░░░░░░░░░░░', 'Avenue Charles Gounod░░░░░░░░░░░░░░░░',
'Rue de l‘Aurore░░░░░░░░░░░░░░░░░░░░░░', 'Rue de l‘Aurore░░░░░░░░░░░░░░░░░░░░░░',
'Rue de Beuzeville░░░░░░░░░░░░░░░░░░░░', 'Rue de Beuzeville░░░░░░░░░░░░░░░░░░░░',
'Rue d‘Orbey░░░░░░░░░░░░░░░░░░░░░░░░░░', 'Rue d‘Orbey░░░░░░░░░░░░░░░░░░░░░░░░░░',
'Rue Victor Schoulcher░░░░░░░░░░░░░░░░', 'Rue Victor Schoulcher░░░░░░░░░░░░░░░░',
'Rue de la Gare░░░░░░░░░░░░░░░░░░░░░░░', 'Rue de la Gare░░░░░░░░░░░░░░░░░░░░░░░',
'Rue des Tulipes░░░░░░░░░░░░░░░░░░░░░░', 'Rue des Tulipes░░░░░░░░░░░░░░░░░░░░░░',
'Rue André Maginot░░░░░░░░░░░░░░░░░░░░', 'Rue André Maginot░░░░░░░░░░░░░░░░░░░░',
'Route de Pringy░░░░░░░░░░░░░░░░░░░░░░', 'Route de Pringy░░░░░░░░░░░░░░░░░░░░░░',
'Rue des Landelles░░░░░░░░░░░░░░░░░░░░', 'Rue des Landelles░░░░░░░░░░░░░░░░░░░░',
'Rue des Ilettes░░░░░░░░░░░░░░░░░░░░░░', 'Rue des Ilettes░░░░░░░░░░░░░░░░░░░░░░',
'Avenue de Maurin░░░░░░░░░░░░░░░░░░░░░', 'Avenue de Maurin░░░░░░░░░░░░░░░░░░░░░',
'Rue Théresa░░░░░░░░░░░░░░░░░░░░░░░░░░', # GT='Rue Thérésa' 'Rue Théresa░░░░░░░░░░░░░░░░░░░░░░░░░░', # GT='Rue Thérésa'
'Route de la Balme░░░░░░░░░░░░░░░░░░░░', 'Route de la Balme░░░░░░░░░░░░░░░░░░░░',
'Rue Hélène Roederer░░░░░░░░░░░░░░░░░░', 'Rue Hélène Roederer░░░░░░░░░░░░░░░░░░',
'Rue Emile Bernard░░░░░░░░░░░░░░░░░░░░', 'Rue Emile Bernard░░░░░░░░░░░░░░░░░░░░',
'Place de la Mairie░░░░░░░░░░░░░░░░░░░', 'Place de la Mairie░░░░░░░░░░░░░░░░░░░',
'Rue des Perrots░░░░░░░░░░░░░░░░░░░░░░', 'Rue des Perrots░░░░░░░░░░░░░░░░░░░░░░',
'Rue de la Libération░░░░░░░░░░░░░░░░░', 'Rue de la Libération░░░░░░░░░░░░░░░░░',
'Impasse du Capcir░░░░░░░░░░░░░░░░░░░░', 'Impasse du Capcir░░░░░░░░░░░░░░░░░░░░',
'Avenue de la Grand Mare░░░░░░░░░░░░░░', 'Avenue de la Grand Mare░░░░░░░░░░░░░░',
'Rue Pierre Brossolette░░░░░░░░░░░░░░░', 'Rue Pierre Brossolette░░░░░░░░░░░░░░░',
'Rue de Provence░░░░░░░░░░░░░░░░░░░░░░', 'Rue de Provence░░░░░░░░░░░░░░░░░░░░░░',
'Rue du Docteur Mourre░░░░░░░░░░░░░░░░', 'Rue du Docteur Mourre░░░░░░░░░░░░░░░░',
'Rue d‘Ortheuil░░░░░░░░░░░░░░░░░░░░░░░', 'Rue d‘Ortheuil░░░░░░░░░░░░░░░░░░░░░░░',
'Rue des Sarments░░░░░░░░░░░░░░░░░░░░░', 'Rue des Sarments░░░░░░░░░░░░░░░░░░░░░',
'Rue du Centre░░░░░░░░░░░░░░░░░░░░░░░░', 'Rue du Centre░░░░░░░░░░░░░░░░░░░░░░░░',
'Impasse Pierre Mourgues░░░░░░░░░░░░░░', 'Impasse Pierre Mourgues░░░░░░░░░░░░░░',
'Rue Marcel Dassault░░░░░░░░░░░░░░░░░░' 'Rue Marcel Dassault░░░░░░░░░░░░░░░░░░'
], predictions) ], predictions)
......
...@@ -45,8 +45,8 @@ flags.DEFINE_integer('number_of_steps', None, ...@@ -45,8 +45,8 @@ flags.DEFINE_integer('number_of_steps', None,
def main(_): def main(_):
if not tf.gfile.Exists(FLAGS.eval_log_dir): if not tf.io.gfile.exists(FLAGS.eval_log_dir):
tf.gfile.MakeDirs(FLAGS.eval_log_dir) tf.io.gfile.makedirs(FLAGS.eval_log_dir)
dataset = common_flags.create_dataset(split_name=FLAGS.split_name) dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
model = common_flags.create_model(dataset.num_char_classes, model = common_flags.create_model(dataset.num_char_classes,
...@@ -62,7 +62,7 @@ def main(_): ...@@ -62,7 +62,7 @@ def main(_):
eval_ops = model.create_summaries( eval_ops = model.create_summaries(
data, endpoints, dataset.charset, is_training=False) data, endpoints, dataset.charset, is_training=False)
slim.get_or_create_global_step() slim.get_or_create_global_step()
session_config = tf.ConfigProto(device_count={"GPU": 0}) session_config = tf.compat.v1.ConfigProto(device_count={"GPU": 0})
slim.evaluation.evaluation_loop( slim.evaluation.evaluation_loop(
master=FLAGS.master, master=FLAGS.master,
checkpoint_dir=FLAGS.train_log_dir, checkpoint_dir=FLAGS.train_log_dir,
......
...@@ -38,7 +38,7 @@ def apply_with_random_selector(x, func, num_cases): ...@@ -38,7 +38,7 @@ def apply_with_random_selector(x, func, num_cases):
The result of func(x, sel), where func receives the value of the The result of func(x, sel), where func receives the value of the
selector as a python integer, but sel is sampled dynamically. selector as a python integer, but sel is sampled dynamically.
""" """
sel = tf.random_uniform([], maxval=num_cases, dtype=tf.int32) sel = tf.random.uniform([], maxval=num_cases, dtype=tf.int32)
# Pass the real x only to one of the func calls. # Pass the real x only to one of the func calls.
return control_flow_ops.merge([ return control_flow_ops.merge([
func(control_flow_ops.switch(x, tf.equal(sel, case))[1], case) func(control_flow_ops.switch(x, tf.equal(sel, case))[1], case)
...@@ -64,7 +64,7 @@ def distort_color(image, color_ordering=0, fast_mode=True, scope=None): ...@@ -64,7 +64,7 @@ def distort_color(image, color_ordering=0, fast_mode=True, scope=None):
Raises: Raises:
ValueError: if color_ordering not in [0, 3] ValueError: if color_ordering not in [0, 3]
""" """
with tf.name_scope(scope, 'distort_color', [image]): with tf.compat.v1.name_scope(scope, 'distort_color', [image]):
if fast_mode: if fast_mode:
if color_ordering == 0: if color_ordering == 0:
image = tf.image.random_brightness(image, max_delta=32. / 255.) image = tf.image.random_brightness(image, max_delta=32. / 255.)
...@@ -131,7 +131,7 @@ def distorted_bounding_box_crop(image, ...@@ -131,7 +131,7 @@ def distorted_bounding_box_crop(image,
Returns: Returns:
A tuple, a 3-D Tensor cropped_image and the distorted bbox A tuple, a 3-D Tensor cropped_image and the distorted bbox
""" """
with tf.name_scope(scope, 'distorted_bounding_box_crop', [image, bbox]): with tf.compat.v1.name_scope(scope, 'distorted_bounding_box_crop', [image, bbox]):
# Each bounding box has shape [1, num_boxes, box coords] and # Each bounding box has shape [1, num_boxes, box coords] and
# the coordinates are ordered [ymin, xmin, ymax, xmax]. # the coordinates are ordered [ymin, xmin, ymax, xmax].
...@@ -143,7 +143,7 @@ def distorted_bounding_box_crop(image, ...@@ -143,7 +143,7 @@ def distorted_bounding_box_crop(image,
# bounding box. If no box is supplied, then we assume the bounding box is # bounding box. If no box is supplied, then we assume the bounding box is
# the entire image. # the entire image.
sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box( sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
tf.shape(image), image_size=tf.shape(input=image),
bounding_boxes=bbox, bounding_boxes=bbox,
min_object_covered=min_object_covered, min_object_covered=min_object_covered,
aspect_ratio_range=aspect_ratio_range, aspect_ratio_range=aspect_ratio_range,
...@@ -188,7 +188,7 @@ def preprocess_for_train(image, ...@@ -188,7 +188,7 @@ def preprocess_for_train(image,
Returns: Returns:
3-D float Tensor of distorted image used for training with range [-1, 1]. 3-D float Tensor of distorted image used for training with range [-1, 1].
""" """
with tf.name_scope(scope, 'distort_image', [image, height, width, bbox]): with tf.compat.v1.name_scope(scope, 'distort_image', [image, height, width, bbox]):
if bbox is None: if bbox is None:
bbox = tf.constant( bbox = tf.constant(
[0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4]) [0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
...@@ -198,7 +198,7 @@ def preprocess_for_train(image, ...@@ -198,7 +198,7 @@ def preprocess_for_train(image,
# the coordinates are ordered [ymin, xmin, ymax, xmax]. # the coordinates are ordered [ymin, xmin, ymax, xmax].
image_with_box = tf.image.draw_bounding_boxes( image_with_box = tf.image.draw_bounding_boxes(
tf.expand_dims(image, 0), bbox) tf.expand_dims(image, 0), bbox)
tf.summary.image('image_with_bounding_boxes', image_with_box) tf.compat.v1.summary.image('image_with_bounding_boxes', image_with_box)
distorted_image, distorted_bbox = distorted_bounding_box_crop(image, bbox) distorted_image, distorted_bbox = distorted_bounding_box_crop(image, bbox)
# Restore the shape since the dynamic slice based upon the bbox_size loses # Restore the shape since the dynamic slice based upon the bbox_size loses
...@@ -206,8 +206,8 @@ def preprocess_for_train(image, ...@@ -206,8 +206,8 @@ def preprocess_for_train(image,
distorted_image.set_shape([None, None, 3]) distorted_image.set_shape([None, None, 3])
image_with_distorted_box = tf.image.draw_bounding_boxes( image_with_distorted_box = tf.image.draw_bounding_boxes(
tf.expand_dims(image, 0), distorted_bbox) tf.expand_dims(image, 0), distorted_bbox)
tf.summary.image('images_with_distorted_bounding_box', tf.compat.v1.summary.image('images_with_distorted_bounding_box',
image_with_distorted_box) image_with_distorted_box)
# This resizing operation may distort the images because the aspect # This resizing operation may distort the images because the aspect
# ratio is not respected. We select a resize method in a round robin # ratio is not respected. We select a resize method in a round robin
...@@ -218,11 +218,11 @@ def preprocess_for_train(image, ...@@ -218,11 +218,11 @@ def preprocess_for_train(image,
num_resize_cases = 1 if fast_mode else 4 num_resize_cases = 1 if fast_mode else 4
distorted_image = apply_with_random_selector( distorted_image = apply_with_random_selector(
distorted_image, distorted_image,
lambda x, method: tf.image.resize_images(x, [height, width], method=method), lambda x, method: tf.image.resize(x, [height, width], method=method),
num_cases=num_resize_cases) num_cases=num_resize_cases)
tf.summary.image('cropped_resized_image', tf.compat.v1.summary.image('cropped_resized_image',
tf.expand_dims(distorted_image, 0)) tf.expand_dims(distorted_image, 0))
# Randomly flip the image horizontally. # Randomly flip the image horizontally.
distorted_image = tf.image.random_flip_left_right(distorted_image) distorted_image = tf.image.random_flip_left_right(distorted_image)
...@@ -233,8 +233,8 @@ def preprocess_for_train(image, ...@@ -233,8 +233,8 @@ def preprocess_for_train(image,
lambda x, ordering: distort_color(x, ordering, fast_mode), lambda x, ordering: distort_color(x, ordering, fast_mode),
num_cases=4) num_cases=4)
tf.summary.image('final_distorted_image', tf.compat.v1.summary.image('final_distorted_image',
tf.expand_dims(distorted_image, 0)) tf.expand_dims(distorted_image, 0))
distorted_image = tf.subtract(distorted_image, 0.5) distorted_image = tf.subtract(distorted_image, 0.5)
distorted_image = tf.multiply(distorted_image, 2.0) distorted_image = tf.multiply(distorted_image, 2.0)
return distorted_image return distorted_image
...@@ -265,7 +265,7 @@ def preprocess_for_eval(image, ...@@ -265,7 +265,7 @@ def preprocess_for_eval(image,
Returns: Returns:
3-D float Tensor of prepared image. 3-D float Tensor of prepared image.
""" """
with tf.name_scope(scope, 'eval_image', [image, height, width]): with tf.compat.v1.name_scope(scope, 'eval_image', [image, height, width]):
if image.dtype != tf.float32: if image.dtype != tf.float32:
image = tf.image.convert_image_dtype(image, dtype=tf.float32) image = tf.image.convert_image_dtype(image, dtype=tf.float32)
# Crop the central region of the image with an area containing 87.5% of # Crop the central region of the image with an area containing 87.5% of
...@@ -276,8 +276,8 @@ def preprocess_for_eval(image, ...@@ -276,8 +276,8 @@ def preprocess_for_eval(image,
if height and width: if height and width:
# Resize the image to the specified height and width. # Resize the image to the specified height and width.
image = tf.expand_dims(image, 0) image = tf.expand_dims(image, 0)
image = tf.image.resize_bilinear( image = tf.image.resize(
image, [height, width], align_corners=False) image, [height, width], method=tf.image.ResizeMethod.BILINEAR)
image = tf.squeeze(image, [0]) image = tf.squeeze(image, [0])
image = tf.subtract(image, 0.5) image = tf.subtract(image, 0.5)
image = tf.multiply(image, 2.0) image = tf.multiply(image, 2.0)
......
...@@ -34,20 +34,21 @@ def char_accuracy(predictions, targets, rej_char, streaming=False): ...@@ -34,20 +34,21 @@ def char_accuracy(predictions, targets, rej_char, streaming=False):
a update_ops for execution and value tensor whose value on evaluation a update_ops for execution and value tensor whose value on evaluation
returns the total character accuracy. returns the total character accuracy.
""" """
with tf.variable_scope('CharAccuracy'): with tf.compat.v1.variable_scope('CharAccuracy'):
predictions.get_shape().assert_is_compatible_with(targets.get_shape()) predictions.get_shape().assert_is_compatible_with(targets.get_shape())
targets = tf.to_int32(targets) targets = tf.cast(targets, dtype=tf.int32)
const_rej_char = tf.constant(rej_char, shape=targets.get_shape()) const_rej_char = tf.constant(rej_char, shape=targets.get_shape())
weights = tf.to_float(tf.not_equal(targets, const_rej_char)) weights = tf.cast(tf.not_equal(targets, const_rej_char), dtype=tf.float32)
correct_chars = tf.to_float(tf.equal(predictions, targets)) correct_chars = tf.cast(tf.equal(predictions, targets), dtype=tf.float32)
accuracy_per_example = tf.div( accuracy_per_example = tf.compat.v1.div(
tf.reduce_sum(tf.multiply(correct_chars, weights), 1), tf.reduce_sum(input_tensor=tf.multiply(
tf.reduce_sum(weights, 1)) correct_chars, weights), axis=1),
tf.reduce_sum(input_tensor=weights, axis=1))
if streaming: if streaming:
return tf.contrib.metrics.streaming_mean(accuracy_per_example) return tf.contrib.metrics.streaming_mean(accuracy_per_example)
else: else:
return tf.reduce_mean(accuracy_per_example) return tf.reduce_mean(input_tensor=accuracy_per_example)
def sequence_accuracy(predictions, targets, rej_char, streaming=False): def sequence_accuracy(predictions, targets, rej_char, streaming=False):
...@@ -66,25 +67,26 @@ def sequence_accuracy(predictions, targets, rej_char, streaming=False): ...@@ -66,25 +67,26 @@ def sequence_accuracy(predictions, targets, rej_char, streaming=False):
returns the total sequence accuracy. returns the total sequence accuracy.
""" """
with tf.variable_scope('SequenceAccuracy'): with tf.compat.v1.variable_scope('SequenceAccuracy'):
predictions.get_shape().assert_is_compatible_with(targets.get_shape()) predictions.get_shape().assert_is_compatible_with(targets.get_shape())
targets = tf.to_int32(targets) targets = tf.cast(targets, dtype=tf.int32)
const_rej_char = tf.constant( const_rej_char = tf.constant(
rej_char, shape=targets.get_shape(), dtype=tf.int32) rej_char, shape=targets.get_shape(), dtype=tf.int32)
include_mask = tf.not_equal(targets, const_rej_char) include_mask = tf.not_equal(targets, const_rej_char)
include_predictions = tf.to_int32( include_predictions = tf.cast(
tf.where(include_mask, predictions, tf.compat.v1.where(include_mask, predictions,
tf.zeros_like(predictions) + rej_char)) tf.zeros_like(predictions) + rej_char), dtype=tf.int32)
correct_chars = tf.to_float(tf.equal(include_predictions, targets)) correct_chars = tf.cast(
tf.equal(include_predictions, targets), dtype=tf.float32)
correct_chars_counts = tf.cast( correct_chars_counts = tf.cast(
tf.reduce_sum(correct_chars, reduction_indices=[1]), dtype=tf.int32) tf.reduce_sum(input_tensor=correct_chars, axis=[1]), dtype=tf.int32)
target_length = targets.get_shape().dims[1].value target_length = targets.get_shape().dims[1].value
target_chars_counts = tf.constant( target_chars_counts = tf.constant(
target_length, shape=correct_chars_counts.get_shape()) target_length, shape=correct_chars_counts.get_shape())
accuracy_per_example = tf.to_float( accuracy_per_example = tf.cast(
tf.equal(correct_chars_counts, target_chars_counts)) tf.equal(correct_chars_counts, target_chars_counts), dtype=tf.float32)
if streaming: if streaming:
return tf.contrib.metrics.streaming_mean(accuracy_per_example) return tf.contrib.metrics.streaming_mean(accuracy_per_example)
else: else:
return tf.reduce_mean(accuracy_per_example) return tf.reduce_mean(input_tensor=accuracy_per_example)
...@@ -38,8 +38,8 @@ class AccuracyTest(tf.test.TestCase): ...@@ -38,8 +38,8 @@ class AccuracyTest(tf.test.TestCase):
A session object that should be used as a context manager. A session object that should be used as a context manager.
""" """
with self.cached_session() as sess: with self.cached_session() as sess:
sess.run(tf.global_variables_initializer()) sess.run(tf.compat.v1.global_variables_initializer())
sess.run(tf.local_variables_initializer()) sess.run(tf.compat.v1.local_variables_initializer())
yield sess yield sess
def _fake_labels(self): def _fake_labels(self):
...@@ -55,7 +55,7 @@ class AccuracyTest(tf.test.TestCase): ...@@ -55,7 +55,7 @@ class AccuracyTest(tf.test.TestCase):
return incorrect return incorrect
def test_sequence_accuracy_identical_samples(self): def test_sequence_accuracy_identical_samples(self):
labels_tf = tf.convert_to_tensor(self._fake_labels()) labels_tf = tf.convert_to_tensor(value=self._fake_labels())
accuracy_tf = metrics.sequence_accuracy(labels_tf, labels_tf, accuracy_tf = metrics.sequence_accuracy(labels_tf, labels_tf,
self.rej_char) self.rej_char)
...@@ -66,9 +66,9 @@ class AccuracyTest(tf.test.TestCase): ...@@ -66,9 +66,9 @@ class AccuracyTest(tf.test.TestCase):
def test_sequence_accuracy_one_char_difference(self): def test_sequence_accuracy_one_char_difference(self):
ground_truth_np = self._fake_labels() ground_truth_np = self._fake_labels()
ground_truth_tf = tf.convert_to_tensor(ground_truth_np) ground_truth_tf = tf.convert_to_tensor(value=ground_truth_np)
prediction_tf = tf.convert_to_tensor( prediction_tf = tf.convert_to_tensor(
self._incorrect_copy(ground_truth_np, bad_indexes=((0, 0)))) value=self._incorrect_copy(ground_truth_np, bad_indexes=((0, 0))))
accuracy_tf = metrics.sequence_accuracy(prediction_tf, ground_truth_tf, accuracy_tf = metrics.sequence_accuracy(prediction_tf, ground_truth_tf,
self.rej_char) self.rej_char)
...@@ -80,9 +80,9 @@ class AccuracyTest(tf.test.TestCase): ...@@ -80,9 +80,9 @@ class AccuracyTest(tf.test.TestCase):
def test_char_accuracy_one_char_difference_with_padding(self): def test_char_accuracy_one_char_difference_with_padding(self):
ground_truth_np = self._fake_labels() ground_truth_np = self._fake_labels()
ground_truth_tf = tf.convert_to_tensor(ground_truth_np) ground_truth_tf = tf.convert_to_tensor(value=ground_truth_np)
prediction_tf = tf.convert_to_tensor( prediction_tf = tf.convert_to_tensor(
self._incorrect_copy(ground_truth_np, bad_indexes=((0, 0)))) value=self._incorrect_copy(ground_truth_np, bad_indexes=((0, 0))))
accuracy_tf = metrics.char_accuracy(prediction_tf, ground_truth_tf, accuracy_tf = metrics.char_accuracy(prediction_tf, ground_truth_tf,
self.rej_char) self.rej_char)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment