Commit 5a2cf36f authored by Kaushik Shivakumar's avatar Kaushik Shivakumar
Browse files

Merge remote-tracking branch 'upstream/master' into newavarecords

parents 258ddfc3 a829e648
......@@ -14,33 +14,21 @@
# ==============================================================================
"""Runs a ResNet model on the ImageNet dataset using custom training loops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import orbit
import tensorflow as tf
from official.modeling import performance
from official.staging.training import grad_utils
from official.staging.training import standard_runnable
from official.staging.training import utils
from official.utils.flags import core as flags_core
from official.vision.image_classification.resnet import common
from official.vision.image_classification.resnet import imagenet_preprocessing
from official.vision.image_classification.resnet import resnet_model
class ResnetRunnable(standard_runnable.StandardTrainable,
standard_runnable.StandardEvaluable):
class ResnetRunnable(orbit.StandardTrainer, orbit.StandardEvaluator):
"""Implements the training and evaluation APIs for Resnet model."""
def __init__(self, flags_obj, time_callback, epoch_steps):
standard_runnable.StandardTrainable.__init__(self,
flags_obj.use_tf_while_loop,
flags_obj.use_tf_function)
standard_runnable.StandardEvaluable.__init__(self,
flags_obj.use_tf_function)
self.strategy = tf.distribute.get_strategy()
self.flags_obj = flags_obj
self.dtype = flags_core.get_tf_dtype(flags_obj)
......@@ -107,11 +95,8 @@ class ResnetRunnable(standard_runnable.StandardTrainable,
# Handling epochs.
self.epoch_steps = epoch_steps
self.epoch_helper = utils.EpochHelper(epoch_steps, self.global_step)
def build_train_dataset(self):
"""See base class."""
return utils.make_distributed_dataset(
self.epoch_helper = orbit.utils.EpochHelper(epoch_steps, self.global_step)
train_dataset = orbit.utils.make_distributed_dataset(
self.strategy,
self.input_fn,
is_training=True,
......@@ -122,17 +107,20 @@ class ResnetRunnable(standard_runnable.StandardTrainable,
.datasets_num_private_threads,
dtype=self.dtype,
drop_remainder=True)
def build_eval_dataset(self):
"""See base class."""
return utils.make_distributed_dataset(
self.strategy,
self.input_fn,
is_training=False,
data_dir=self.flags_obj.data_dir,
batch_size=self.batch_size,
parse_record_fn=imagenet_preprocessing.parse_record,
dtype=self.dtype)
orbit.StandardTrainer.__init__(self, train_dataset,
flags_obj.use_tf_while_loop,
flags_obj.use_tf_function)
if not flags_obj.skip_eval:
eval_dataset = orbit.utils.make_distributed_dataset(
self.strategy,
self.input_fn,
is_training=False,
data_dir=self.flags_obj.data_dir,
batch_size=self.batch_size,
parse_record_fn=imagenet_preprocessing.parse_record,
dtype=self.dtype)
orbit.StandardEvaluator.__init__(self, eval_dataset,
flags_obj.use_tf_function)
def train_loop_begin(self):
"""See base class."""
......
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
\ No newline at end of file
![TensorFlow Requirement: 2.x](https://img.shields.io/badge/TensorFlow%20Requirement-2.x-brightgreen)
# Orbit
Orbit is a customized training loop library built on top of Tensorflow 2. It
provides a flexible lightweight library that users can easily use or fork when
writing [customized training loop code](https://www.tensorflow.org/tutorials/distribute/custom_training)
in TF2. It integrates with `tf.distribute` seamlessly and supports running on
different device types (CPU, GPU, and TPU).
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Orbit package definition."""
from orbit import utils
from orbit.controller import Controller
from orbit.runner import *
from orbit.standard_runner import *
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -14,51 +15,54 @@
# ==============================================================================
"""A light weight utilities to train TF2 models."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import time
from typing import Callable, Optional, Text, Union
from absl import logging
from orbit import runner
from orbit import utils
import tensorflow as tf
import tensorflow.compat.v2 as tf
from typing import Callable, Dict, Optional, Text
def _log_info(message: Text):
"""Logs `message` to the `info` log, and also prints to stdout."""
logging.info(message)
print(message)
from official.staging.training import utils
def _validate_interval(interval: Optional[int], steps_per_loop: Optional[int],
interval_name: str):
if interval and steps_per_loop and (interval % steps_per_loop != 0):
raise ValueError("The {} interval ({}) must be a multiple "
"of the steps_per_loop ({})".format(
interval_name, interval, steps_per_loop))
class Controller(object):
class Controller:
"""Class that facilitates training and evaluation of models."""
def __init__(
self,
strategy: Optional[tf.distribute.Strategy] = None,
train_fn: Optional[Callable[[tf.Tensor],
Optional[Dict[Text, tf.Tensor]]]] = None,
eval_fn: Optional[Callable[[tf.Tensor],
Optional[Dict[Text, tf.Tensor]]]] = None,
trainer: Optional[runner.AbstractTrainer] = None,
evaluator: Optional[runner.AbstractEvaluator] = None,
global_step: Optional[tf.Variable] = None,
# Train related
train_steps: Optional[int] = None,
steps_per_loop: Optional[int] = None,
summary_dir: Optional[Text] = None,
checkpoint_manager: Optional[tf.train.CheckpointManager] = None,
# summary related
# Summary related
summary_interval: Optional[int] = None,
summary_dir: Optional[Text] = None,
# Evaluation related
eval_summary_dir: Optional[Text] = None,
eval_steps: Optional[int] = None,
eval_interval: Optional[int] = None):
eval_summary_dir: Optional[Text] = None):
"""Constructs a `Controller` instance.
Args:
strategy: An instance of `tf.distribute.Strategy`.
train_fn: A callable defined as `def train_fn(num_steps)`, which
`num_steps` indicates the number of steps to run for each loop.
eval_fn: A callable defined as `def eval_fn(num_steps)`, which `num_steps`
indicates the number of steps for one evaluation.
trainer: An instance of `orbit.AbstractTrainer`, which represents model
training details.
evaluator: An instance of `orbit.AbstractEvaluator`, which represents
model evaluation details.
global_step: An integer `tf.Variable` indicating the global training step
number. Usually this can be obtained from `iterations` property of the
model's optimizer (e.g. `self.optimizer.iterations`), or users can
......@@ -66,259 +70,328 @@ class Controller(object):
own global step variable, it is recommended to create the `tf.Variable`
inside strategy scope, and with
`aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA`.
train_steps: The total (maximum) number of training steps to perform.
steps_per_loop: The number of steps to run in each "inner loop" of
training (passed to the `num_steps` parameter of `train_fn`).
summary_dir: The directory to restore and write checkpoints and summaries.
If None, it will be set to `checkpoint_manager.directory`.
training (passed to the `num_steps` parameter of `trainer.train`).
checkpoint_manager: An instance of `tf.train.CheckpointManager`.
summary_interval: Step interval for training summaries. Note that this
argument only applies to the summaries outside the training loop. If the
value is None, then training summaries are not enabled.
argument only applies to the summaries inside `trainer.train` function.
Summaries outside like "steps_per_second" and outputs from
`trainer.train` function will always be enabled. If set, the value
should be divisible by steps_per_loop.
summary_dir: The directory to restore and write checkpoints and summaries.
If None, it will be set to `checkpoint_manager.directory`.
eval_summary_dir: The directory to write eval summaries. If None, it will
be set to `summary_dir`.
eval_steps: Number of steps to run evaluation.
eval_interval: Step interval for evaluation. If None, will skip evaluation
in the middle of training. Note that evaluation only happens outside the
training loop, whose iteration count is specified by the `steps_per_loop`
parameter.
Raises:
ValueError: If both `train_fn` and `eval_fn` are None.
ValueError: If `train_fn` is not None and `train_steps` is None.
ValueError: If `steps_per_loop` is None when `train_fn` is provided.
ValueError: If both `trainer` and `evaluator` are None.
ValueError: If `steps_per_loop` is not a positive integer.
ValueError: If `summary_interval` is not a positive integer or it cannot
be divisible by `steps_per_loop`.
"""
if train_fn is None and eval_fn is None:
raise ValueError("`train_fn` and `eval_fn` should not both be None")
# TODO(rxsang): Support training until exhaustion by passing
# `train_steps=-1`. Currently it cannot be supported with a host training
# loop because break statements are not supported with distributed dataset.
if train_fn is not None:
if train_steps is None:
raise ValueError("`train_steps` is required when `train_fn` is "
"provided.")
if trainer is None and evaluator is None:
raise ValueError("`trainer` and `evaluator` should not both be None")
if trainer is not None:
if steps_per_loop is None:
raise ValueError("`steps_per_loop` is required when `train_fn is "
raise ValueError("`steps_per_loop` is required when `trainer` is "
"provided.")
if not isinstance(steps_per_loop, int) or steps_per_loop < 1:
raise ValueError("`steps_per_loop` should be a positive integer")
if summary_interval is not None and summary_interval <= 0:
raise ValueError("`summary_interval` should be larger than 0")
if summary_interval is not None:
if summary_interval <= 0:
raise ValueError("`summary_interval` should be larger than 0")
_validate_interval(
summary_interval, steps_per_loop, interval_name="summary")
self.trainer = trainer
self.evaluator = evaluator
self.strategy = strategy or tf.distribute.get_strategy()
self.train_fn = train_fn
self.eval_fn = eval_fn
self.global_step = global_step
self.checkpoint_manager = checkpoint_manager
if self.train_fn is not None:
self.train_steps = train_steps
self.steps_per_loop = steps_per_loop
if summary_dir:
self.summary_dir = summary_dir
elif checkpoint_manager:
self.summary_dir = checkpoint_manager.directory
else:
self.summary_dir = None
if summary_dir is None and checkpoint_manager:
summary_dir = checkpoint_manager.directory
if self.trainer is not None:
self.step_timer = None
self.steps_per_loop = steps_per_loop
self.summary_interval = summary_interval
if self.summary_dir and self.summary_interval:
summary_writer = tf.summary.create_file_writer(self.summary_dir)
else:
summary_writer = None
# TODO(rxsang): Consider pass SummaryManager directly into Controller for
# maximum customizability.
self.summary_manager = utils.SummaryManager(
summary_writer,
tf.summary.scalar,
global_step=self.global_step,
summary_interval=self.summary_interval)
if self.eval_fn is not None:
eval_summary_dir = eval_summary_dir or self.summary_dir
eval_summary_writer = tf.summary.create_file_writer(
eval_summary_dir) if eval_summary_dir else None
self.eval_summary_manager = utils.SummaryManager(
eval_summary_writer, tf.summary.scalar, global_step=self.global_step)
self.eval_steps = eval_steps
self.eval_interval = eval_interval
# Creates and initializes the interval triggers.
self.eval_trigger = utils.IntervalTrigger(self.eval_interval,
self.global_step.numpy()) # pytype: disable=attribute-error
if self.global_step:
summary_dir, tf.summary.scalar, global_step=self.global_step)
eval_summary_writer = None
if self.evaluator is not None:
eval_summary_dir = eval_summary_dir or summary_dir
if eval_summary_dir == summary_dir and self.trainer is not None:
# Reuse the summary writer if train and evaluation summary directory
# are the same.
self.eval_summary_manager = self.summary_manager
else:
self.eval_summary_manager = utils.SummaryManager(
eval_summary_dir, tf.summary.scalar, global_step=self.global_step)
if self.global_step is not None:
tf.summary.experimental.set_step(self.global_step)
# Restores the model if needed.
# TODO(momernick): We probably only want to do this on certain occasions?
if self.checkpoint_manager is not None:
model_restored = self._restore_model()
if not model_restored and self.checkpoint_manager.checkpoint_interval:
# If the model is not restored from a checkpoint, save an initial
checkpoint_interval = self.checkpoint_manager.checkpoint_interval
_validate_interval(
checkpoint_interval, steps_per_loop, interval_name="checkpoint")
model_restored = self.restore_checkpoint()
if not model_restored and (checkpoint_interval and
self.trainer is not None):
# If the model is not restored from a checkpoint, and
# `checkpoint_interval` is enabled for training, save an initial
# checkpoint.
ckpt_path = self.checkpoint_manager.save(
checkpoint_number=self.global_step)
logging.info("Saved checkpoins in %s", ckpt_path)
self.save_checkpoint()
def _restore_model(self, checkpoint_path=None):
"""Restore or initialize the model.
def train(self, steps: int, checkpoint_at_completion: bool = True):
"""Runs training.
This method calls the `train` method on the Trainable object until the
global step count is equal to `steps`. It will optionally save checkpoints,
if a CheckpointManager was passed to the Controller instance's `__init__`.
Args:
checkpoint_path: An optional string indicates the checkpoint path to
restore. If None, will restore from `self.checkpoint_manager`.
steps: The global step count to train up to.
checkpoint_at_completion: Whether to save a checkpoint when this method
returns. Defaults to True (write the checkpoint). This is always
triggered, regardless of the checkpointing interval.
"""
if self.trainer is None:
raise ValueError("`self.trainer` is required when calling `train` "
"method.")
if self.global_step is None:
raise ValueError("`self.global_step` is required when calling `train` "
"method.")
# TODO(momernick): Support steps=None or -1 (training to exhaustion).
current_step = self.global_step.numpy() # This is an expensive access.
while current_step < steps:
logging.info("Train at step %s of %s", current_step, steps)
# Calculates steps to run for the next train loop.
num_steps = min(steps - current_step, self.steps_per_loop)
self._train_n_steps(num_steps)
self._maybe_save_checkpoint()
current_step = self.global_step.numpy() # This is an expensive access.
if checkpoint_at_completion:
self.save_checkpoint()
def evaluate(self, steps: int = None):
"""Runs evaluation.
This method calls the `evaluate` method on the Evaluator object for `steps`
steps, then writes the returned summaries (if any).
Args:
steps: The number of steps to evaluate for.
Raises:
ValueError: If no checkpoint found in `self.checkpoint_manager.directory`.
ValueError: If `evaluator` is not provided.
Returns:
True if the latest checkpoint is found or restored. Otherwise False.
"""
with self.strategy.scope():
# Checkpoint restoring should be inside scope. b/139450638
if checkpoint_path is not None:
self.checkpoint_manager.checkpoint.restore(checkpoint_path)
return True
return self.checkpoint_manager.restore_or_initialize()
if self.evaluator is None:
raise ValueError("`evaluator` must be provided to call `evaluate()` "
"method.")
def _evaluate_once(self, current_step):
"""Runs the evaluation once."""
logging.info("Start evaluation at step: %s", current_step)
steps = steps or -1
current_step = self.global_step.numpy()
if steps > 0:
logging.info("Running %s steps of evaluation at train step: %s", steps,
current_step)
steps = tf.convert_to_tensor(steps, dtype=tf.int32)
else:
logging.info("Evaluating at train step: %s", current_step)
with self.eval_summary_manager.summary_writer.as_default():
eval_outputs = self.eval_fn(self.eval_steps)
eval_outputs = self.evaluator.evaluate(steps)
if eval_outputs:
eval_outputs = tf.nest.map_structure(lambda x: x.numpy(), eval_outputs)
eval_outputs = tf.nest.map_structure(utils.get_value, eval_outputs)
info = "step: {} evaluation metric: {}".format(
current_step, eval_outputs)
self._log_info(info)
_log_info(info)
self.eval_summary_manager.write_summaries(eval_outputs)
self.eval_summary_manager.flush()
def _maybe_save_checkpoints(self, current_step, force_trigger=False):
if self.checkpoint_manager and self.checkpoint_manager.checkpoint_interval:
ckpt_path = self.checkpoint_manager.save(
checkpoint_number=current_step, check_interval=not force_trigger)
if ckpt_path is not None:
logging.info("Saved checkpoins in %s", ckpt_path)
def restore_checkpoint(self, checkpoint_path: Text = None):
"""Restore or initialize the model.
Args:
checkpoint_path: An optional string indicates the checkpoint path to
restore. If None, will restore from `self.checkpoint_manager`.
Returns:
The path to the restored checkpoint if a restore happened, or None
if no restore occurred.
"""
with self.strategy.scope():
# Checkpoint restoring should be inside scope. b/139450638
if checkpoint_path is not None:
self.checkpoint_manager.checkpoint.restore(checkpoint_path)
return checkpoint_path
return self.checkpoint_manager.restore_or_initialize()
def save_checkpoint(self):
"""Checkpoint the model.
def _maybe_evaluate(self, current_step, force_trigger=False):
if self.eval_trigger(current_step, force_trigger):
self._evaluate_once(current_step)
This method will write a checkpoint containing the current state of the
model.
def _log_info(self, message):
"""Logs `message` to the `info` log, and also prints to stdout."""
logging.info(message)
print(message)
Raises:
ValueError: if no CheckpointManager was provided to this Controller's
init args.
"""
self._maybe_save_checkpoint(force_trigger=True)
def train(self, evaluate=True):
"""Runs the training, with optional evaluation.
def train_and_evaluate(self,
train_steps: int = None,
eval_steps: int = None,
eval_interval: int = None):
"""Train and evaluate in an interleaved manner.
This handles evaluation, gathering summaries, and saving checkpoints.
This method will train the model until the global step count equals
`train_steps`, running an evaluation for `eval_steps` every `eval_interval`
training steps. In addition, this method will run a final evaluation at the
end of the training sequence.
Args:
evaluate: A boolean indicates whether to perform evaluation during
training.
train_steps: The global step count to train up to.
eval_steps: The number of steps to run during an evaluation. If None,
this method will evaluate over the entire evaluation dataset.
eval_interval: The number of training steps to run between evaluations.
Must be a multiple of the controller's `steps_per_loop` init arg. If
None, evaluation will only be performed after training is complete.
Raises:
RuntimeError: If `global_step` is not updated correctly in `train_fn`.
ValueError: If eval_interval is not a multiple of self.steps_per_loop.
"""
if self.train_fn is None:
raise ValueError("`self.train_fn` is required when calling `train` "
"method.")
if self.global_step is None:
raise ValueError("`self.global_step` is required when calling `train` "
"method.")
if evaluate and self.eval_fn is None:
raise ValueError("`self.eval_fn` is required when calling `train` method "
"with `evaluate=True`")
step_timer = _StepTimer(self.global_step)
current_step = self.global_step.numpy()
logging.info("Train at step %s of %s", current_step, self.train_steps)
while current_step < self.train_steps:
# Calculates steps to run for the next train loop.
steps_per_loop = min(self.train_steps - current_step, self.steps_per_loop)
logging.info("Entering training loop with %s steps, at step %s of %s",
steps_per_loop, current_step, self.train_steps)
current_step += steps_per_loop
steps_per_loop = tf.convert_to_tensor(steps_per_loop, dtype=tf.int32)
with self.summary_manager.summary_writer.as_default():
train_outputs = self.train_fn(steps_per_loop)
# Updates and verifies the current step after a training loop finishes.
if current_step != self.global_step.numpy():
raise RuntimeError("`self.train_fn` is not updating `global_step` "
"correctly, expected: %s, actual: %s" %
(current_step, self.global_step.numpy()))
# Print information like metrics and steps_per_second after a training
# loop.
if train_outputs:
train_outputs = tf.nest.map_structure(
lambda x: x.numpy(), train_outputs)
steps_per_second = step_timer.steps_per_second()
info = "step: {} steps_per_second: {:.2f} {}".format(
current_step, steps_per_second, train_outputs)
self._log_info(info)
train_outputs = train_outputs or {}
train_outputs["steps_per_second"] = steps_per_second
self.summary_manager.write_summaries(train_outputs)
self._maybe_save_checkpoints(current_step)
if evaluate:
self._maybe_evaluate(current_step)
self.summary_manager.write_summaries(train_outputs, always_write=True)
self.summary_manager.flush()
self._maybe_save_checkpoints(current_step, force_trigger=True)
if evaluate:
self._maybe_evaluate(current_step, force_trigger=True)
def evaluate(self, continuous=False, timeout_fn=None):
"""Runs the evaluation.
_validate_interval(eval_interval, self.steps_per_loop, interval_name="eval")
current_step = self.global_step.numpy() # This is an expensive access.
eval_interval = eval_interval or (train_steps - current_step)
while current_step < train_steps:
interval = min(train_steps - current_step, eval_interval)
num_steps = current_step + interval
self.train(steps=num_steps, checkpoint_at_completion=False)
self.evaluate(steps=eval_steps)
current_step = self.global_step.numpy() # This is an expensive access.
self.save_checkpoint()
  def evaluate_continuously(self,
                            steps: Optional[int] = None,
                            timeout: Optional[Union[int, float]] = None,
                            timeout_fn: Optional[Callable[[], bool]] = None):
    """Monitors a directory and evaluates on each checkpoint appearing in it.

    This method continuously monitors the directory given by this Controller's
    CheckpointManager init arg and runs evaluation on every checkpoint found
    there, blocking between checkpoints until `tf.train.checkpoints_iterator`
    stops yielding (see its `timeout`/`timeout_fn` semantics).

    Args:
      steps: The number of steps to run when evaluating. Presumably `None`
        evaluates over the full dataset — confirm against `self.evaluate`.
      timeout: The maximum number of seconds to wait between checkpoints. See
        tf.train.checkpoints_iterator documentation.
      timeout_fn: Optional callable to call after a timeout. If the function
        returns True, then it means that no new checkpoints will be generated
        and the iterator will exit.

    Raises:
      ValueError: If no checkpoint found in `self.checkpoint_manager.directory`.
      ValueError: If `evaluator` was not provided as a controller init arg.
    """
    # Restore the checkpoint's weights before each evaluation pass.
    for checkpoint_path in tf.train.checkpoints_iterator(
        self.checkpoint_manager.directory,
        timeout=timeout,
        timeout_fn=timeout_fn):
      self.restore_checkpoint(checkpoint_path)
      self.evaluate(steps)
  def _train_n_steps(self, num_steps: int):
    """Runs training for `num_steps` and reports the loop's outputs.

    Any outputs returned by `trainer.train` are logged and written to
    summaries together with the measured `steps_per_second`.

    Args:
      num_steps: An integer indicates how many steps to run for this training
        loop.

    Raises:
      RuntimeError: If `global_step` is not updated correctly in
        `trainer.train`.
    """
    if not self.step_timer:
      self.step_timer = StepTimer(self.global_step)
    # Calculates steps to run for the next train loop.
    current_step = self.global_step.numpy()
    logging.info("Entering training loop at step %s to run %s steps",
                 current_step, num_steps)
    # `current_step` now holds the step count expected after the loop finishes;
    # it is checked against the real global step below.
    current_step += num_steps
    num_steps = tf.convert_to_tensor(num_steps, dtype=tf.int32)
    with self.summary_manager.summary_writer.as_default():
      # Create a lambda that returns true when summaries should be written.
      should_record = False  # Allows static optimization in no-summary cases.
      if self.summary_interval:
        should_record = lambda: (self.global_step % self.summary_interval == 0)
      with tf.summary.record_if(should_record):
        train_outputs = self.trainer.train(num_steps)
    # Updates and verifies the current step after a training loop finishes.
    if current_step != self.global_step.numpy():
      raise RuntimeError("`trainer.train` function is not updating "
                         "`global_step` correctly, expected: %s, actual: %s" %
                         (current_step, self.global_step.numpy()))
    # Print information like metrics and steps_per_second after a training
    # loop.
    if train_outputs:
      # Presumably converts tensors to host values — see `utils.get_value`.
      train_outputs = tf.nest.map_structure(utils.get_value, train_outputs)
    train_outputs = train_outputs or {}
    steps_per_second = self.step_timer.steps_per_second()
    info = "step: {} steps_per_second: {:.2f} {}".format(
        current_step, steps_per_second, train_outputs)
    _log_info(info)
    train_outputs["steps_per_second"] = steps_per_second
    self.summary_manager.write_summaries(train_outputs)
def _maybe_save_checkpoint(self, force_trigger: bool = False):
"""Save checkpoints if necessary.
Args:
force_trigger: A boolean indicates whether to force saving checkpoints
regardless of the checkpoint interval.
Returns:
A boolean indicating whether a checkpoint was saved.
"""
if self.eval_fn is None:
raise ValueError("`self.eval_fn` should not be None to call "
"`evaluate()` method.")
if not continuous and timeout_fn is not None:
raise ValueError("`timeout_fn` can be only passed when `continuous` is "
"True")
if continuous:
for checkpoint_path in tf.train.checkpoints_iterator(
self.checkpoint_manager.directory, timeout_fn=timeout_fn):
self._restore_model(checkpoint_path)
self._evaluate_once(self.global_step.numpy())
return
latest_checkpoint = self.checkpoint_manager.latest_checkpoint
if not latest_checkpoint:
raise ValueError("no checkpoint found in dir %s" %
self.checkpoint_manager.directory)
self._restore_model()
self._evaluate_once(self.global_step.numpy())
class _StepTimer(object):
if self.checkpoint_manager and self.checkpoint_manager.checkpoint_interval:
ckpt_path = self.checkpoint_manager.save(
checkpoint_number=self.global_step.numpy(),
check_interval=not force_trigger)
if ckpt_path is not None:
logging.info("Saved checkpoints in %s", ckpt_path)
return True
return False
class StepTimer:
"""Utility class for measuring steps/second."""
def __init__(self, step):
......
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,35 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for official.staging.training.controller."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
"""Tests for orbit.controller."""
import os
from absl import logging
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.staging.training import controller
from official.staging.training import standard_runnable
from orbit import controller
from orbit import standard_runner
def all_strategy_combinations():
  """Returns eager-mode test combinations over the supported strategies."""
  strategies = [
      strategy_combinations.one_device_strategy,
      strategy_combinations.tpu_strategy,
      strategy_combinations.one_device_strategy_gpu,
      strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
  ]
  return combinations.combine(strategy=strategies, mode="eager")
import tensorflow as tf
def create_model():
......@@ -57,7 +39,7 @@ def summaries_with_matching_keyword(keyword, summary_dir):
if event.summary is not None:
for value in event.summary.value:
if keyword in value.tag:
tf.compat.v1.logging.error(event)
logging.info(event)
yield event.summary
......@@ -69,30 +51,33 @@ def check_eventfile_for_keyword(keyword, summary_dir):
def dataset_fn(ctx):
  """Builds a small synthetic (inputs, targets) dataset for the tests.

  Args:
    ctx: A `tf.distribute.InputContext`; unused, but required by
      `experimental_distribute_datasets_from_function`.

  Returns:
    A `tf.data.Dataset` yielding 100 batches of 10 examples each, with
    zero-valued float32 inputs of shape (3,) and one-valued float32 targets
    of shape (4,).
  """
  del ctx
  inputs = np.zeros((10, 3), dtype=np.float32)
  # The merge left a dead `np.zeros((10, 4))` targets assignment here that was
  # immediately overwritten; only ones-valued targets are produced.
  targets = np.ones((10, 4), dtype=np.float32)
  dataset = tf.data.Dataset.from_tensor_slices((inputs, targets))
  dataset = dataset.repeat(100)
  dataset = dataset.batch(10, drop_remainder=True)
  return dataset
class TestRunnable(standard_runnable.StandardTrainable,
standard_runnable.StandardEvaluable):
class TestRunner(standard_runner.StandardTrainer,
standard_runner.StandardEvaluator):
"""Implements the training and evaluation APIs for the test model."""
def __init__(self):
standard_runnable.StandardTrainable.__init__(self)
standard_runnable.StandardEvaluable.__init__(self)
def __init__(self, return_numpy=False):
self.strategy = tf.distribute.get_strategy()
self.model = create_model()
self.optimizer = tf.keras.optimizers.RMSprop()
self.optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.1)
self.global_step = self.optimizer.iterations
self.train_loss = tf.keras.metrics.Mean("train_loss", dtype=tf.float32)
self.eval_loss = tf.keras.metrics.Mean("eval_loss", dtype=tf.float32)
def build_train_dataset(self):
return self.strategy.experimental_distribute_datasets_from_function(
dataset_fn)
self.return_numpy = return_numpy
train_dataset = (
self.strategy.experimental_distribute_datasets_from_function(dataset_fn)
)
eval_dataset = (
self.strategy.experimental_distribute_datasets_from_function(dataset_fn)
)
standard_runner.StandardTrainer.__init__(self, train_dataset)
standard_runner.StandardEvaluator.__init__(self, eval_dataset)
def train_step(self, iterator):
......@@ -101,7 +86,7 @@ class TestRunnable(standard_runnable.StandardTrainable,
inputs, targets = inputs
with tf.GradientTape() as tape:
outputs = self.model(inputs)
loss = tf.math.reduce_sum(outputs - targets)
loss = tf.reduce_mean(tf.keras.losses.MSE(targets, outputs))
grads = tape.gradient(loss, self.model.variables)
self.optimizer.apply_gradients(zip(grads, self.model.variables))
self.train_loss.update_state(loss)
......@@ -109,8 +94,9 @@ class TestRunnable(standard_runnable.StandardTrainable,
self.strategy.run(_replicated_step, args=(next(iterator),))
def train_loop_end(self):
train_loss = self.train_loss.result()
return {
"loss": self.train_loss.result(),
"loss": train_loss.numpy() if self.return_numpy else train_loss,
}
def build_eval_dataset(self):
......@@ -126,39 +112,110 @@ class TestRunnable(standard_runnable.StandardTrainable,
"""Replicated evaluation step."""
inputs, targets = inputs
outputs = self.model(inputs)
loss = tf.math.reduce_sum(outputs - targets)
loss = tf.reduce_mean(tf.keras.losses.MSE(targets, outputs))
self.eval_loss.update_state(loss)
self.strategy.run(_replicated_step, args=(next(iterator),))
def eval_end(self):
eval_loss = self.eval_loss.result()
return {
"eval_loss": eval_loss.numpy() if self.return_numpy else eval_loss,
}
class TestEvaluator(standard_runner.StandardEvaluator):
  """Implements the evaluation API for the test model.

  Unlike `TestRunner`, this evaluator accumulates per-step losses via
  `eval_reduce` and averages them in `eval_end` instead of using a metric.
  """

  def __init__(self):
    self.strategy = tf.distribute.get_strategy()
    self.model = create_model()
    eval_dataset = self.strategy.experimental_distribute_datasets_from_function(
        dataset_fn)
    standard_runner.StandardEvaluator.__init__(self, eval_dataset)

  def eval_reduce(self, state, output):
    # Accumulates each step's mean loss into the running state list.
    state.append(output)
    return state

  def eval_begin(self):
    # Fresh accumulator for each evaluation run.
    return []

  def eval_step(self, iterator):
    def _replicated_step(inputs):
      """Replicated evaluation step."""
      inputs, targets = inputs
      outputs = self.model(inputs)
      loss = tf.reduce_mean(tf.keras.losses.MSE(targets, outputs))
      return loss

    per_replica_losses = self.strategy.run(
        _replicated_step, args=(next(iterator),))
    mean_loss = self.strategy.reduce(
        tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
    return mean_loss

  def eval_end(self, outputs):
    # The merge left a stale `"eval_loss": self.eval_loss.result()` entry here.
    # This class defines no `eval_loss` metric, so evaluating that expression
    # raised AttributeError; only the mean over accumulated losses is returned.
    return {
        "eval_loss": tf.reduce_mean(outputs),
    }
class TestTrainerWithSummaries(standard_runner.StandardTrainer):
  """A Trainer model with summaries for testing purposes."""

  def __init__(self):
    self.strategy = tf.distribute.get_strategy()
    self.model = create_model()
    self.optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.1)
    # Reuse the optimizer's iteration counter as the global step.
    self.global_step = self.optimizer.iterations
    self.train_loss = tf.keras.metrics.Mean("train_loss", dtype=tf.float32)
    train_dataset = (
        self.strategy.experimental_distribute_datasets_from_function(dataset_fn)
    )
    standard_runner.StandardTrainer.__init__(
        self, train_dataset, use_tpu_summary_optimization=True)

  # NOTE(review): this looks like a leftover of the old
  # `StandardTrainable.build_train_dataset` API — the dataset is already passed
  # to `StandardTrainer.__init__` above. Confirm it is still needed.
  def build_train_dataset(self):
    return self.strategy.experimental_distribute_datasets_from_function(
        dataset_fn)

  def train_step(self, iterator):
    def _replicated_step(inputs):
      """Replicated training step."""
      inputs, targets = inputs
      with tf.GradientTape() as tape:
        outputs = self.model(inputs)
        loss = tf.reduce_mean(tf.keras.losses.MSE(targets, outputs))
        # Written inside the step so the TPU summary optimization is exercised.
        tf.summary.scalar("loss", loss)
      grads = tape.gradient(loss, self.model.variables)
      self.optimizer.apply_gradients(zip(grads, self.model.variables))
      self.train_loss.update_state(loss)

    self.strategy.run(_replicated_step, args=(next(iterator),))
class ControllerTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super(ControllerTest, self).setUp()
super().setUp()
self.model_dir = self.get_temp_dir()
def test_no_checkpoint(self):
test_runnable = TestRunnable()
test_runner = TestRunner()
# No checkpoint manager and no strategy.
test_controller = controller.Controller(
train_fn=test_runnable.train,
eval_fn=test_runnable.evaluate,
global_step=test_runnable.global_step,
train_steps=10,
trainer=test_runner,
evaluator=test_runner,
global_step=test_runner.global_step,
steps_per_loop=2,
summary_dir=os.path.join(self.model_dir, "summaries/train"),
summary_interval=2,
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
eval_steps=2,
eval_interval=5)
test_controller.train(evaluate=True)
self.assertEqual(test_runnable.global_step.numpy(), 10)
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
test_controller.train_and_evaluate(
train_steps=10, eval_steps=2, eval_interval=6)
self.assertEqual(test_runner.global_step, 10)
# Loss and accuracy values should be written into summaries.
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
......@@ -171,51 +228,46 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
check_eventfile_for_keyword(
"eval_loss", os.path.join(self.model_dir, "summaries/eval")))
# No checkpoint, so global step starts from 0.
test_runnable.global_step.assign(0)
test_controller.train(evaluate=True)
self.assertEqual(test_runnable.global_step.numpy(), 10)
test_runner.global_step.assign(0)
test_controller.train_and_evaluate(
train_steps=10, eval_steps=2, eval_interval=6)
self.assertEqual(test_runner.global_step, 10)
def test_no_checkpoint_and_summaries(self):
test_runnable = TestRunnable()
test_runner = TestRunner()
# No checkpoint + summary directories.
test_controller = controller.Controller(
train_fn=test_runnable.train,
eval_fn=test_runnable.evaluate,
global_step=test_runnable.global_step,
train_steps=10,
steps_per_loop=2,
eval_steps=2,
eval_interval=5)
test_controller.train(evaluate=True)
self.assertEqual(test_runnable.global_step.numpy(), 10)
@combinations.generate(all_strategy_combinations())
def test_train_and_evaluate(self, strategy):
with strategy.scope():
test_runnable = TestRunnable()
trainer=test_runner,
evaluator=test_runner,
global_step=test_runner.global_step,
steps_per_loop=2)
test_controller.train_and_evaluate(
train_steps=10, eval_steps=2, eval_interval=6)
self.assertEqual(test_runner.global_step, 10)
@parameterized.named_parameters(("return_numpy", True),
("return_tensor", False))
def test_train_and_evaluate(self, return_numpy):
test_runner = TestRunner(return_numpy=return_numpy)
checkpoint = tf.train.Checkpoint(
model=test_runnable.model, optimizer=test_runnable.optimizer)
model=test_runner.model, optimizer=test_runner.optimizer)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
self.model_dir,
max_to_keep=None,
step_counter=test_runnable.global_step,
step_counter=test_runner.global_step,
checkpoint_interval=10)
test_controller = controller.Controller(
strategy=strategy,
train_fn=test_runnable.train,
eval_fn=test_runnable.evaluate,
global_step=test_runnable.global_step,
train_steps=10,
trainer=test_runner,
evaluator=test_runner,
global_step=test_runner.global_step,
steps_per_loop=2,
summary_dir=os.path.join(self.model_dir, "summaries/train"),
summary_interval=2,
checkpoint_manager=checkpoint_manager,
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
eval_steps=2,
eval_interval=5)
test_controller.train(evaluate=True)
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
test_controller.train_and_evaluate(
train_steps=10, eval_steps=2, eval_interval=6)
# Checkpoints are saved.
self.assertNotEmpty(tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt*")))
......@@ -232,31 +284,26 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
check_eventfile_for_keyword(
"eval_loss", os.path.join(self.model_dir, "summaries/eval")))
@combinations.generate(all_strategy_combinations())
def test_train_only(self, strategy):
with strategy.scope():
test_runnable = TestRunnable()
def test_train_only(self):
test_runner = TestRunner()
checkpoint = tf.train.Checkpoint(
model=test_runnable.model, optimizer=test_runnable.optimizer)
model=test_runner.model, optimizer=test_runner.optimizer)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
self.model_dir,
max_to_keep=None,
step_counter=test_runnable.global_step,
step_counter=test_runner.global_step,
checkpoint_interval=10)
test_controller = controller.Controller(
strategy=strategy,
train_fn=test_runnable.train,
global_step=test_runnable.global_step,
train_steps=10,
trainer=test_runner,
global_step=test_runner.global_step,
steps_per_loop=2,
summary_dir=os.path.join(self.model_dir, "summaries/train"),
summary_interval=2,
checkpoint_manager=checkpoint_manager,
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
)
test_controller.train(evaluate=False)
test_controller.train(steps=10)
# Checkpoints are saved.
self.assertNotEmpty(tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt*")))
......@@ -270,29 +317,23 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
self.assertFalse(
tf.io.gfile.exists(os.path.join(self.model_dir, "summaries/eval")))
@combinations.generate(all_strategy_combinations())
def test_evaluate_only(self, strategy):
with strategy.scope():
test_runnable = TestRunnable()
def test_evaluate_only(self):
test_runner = TestRunner()
checkpoint = tf.train.Checkpoint(model=test_runnable.model)
checkpoint = tf.train.Checkpoint(model=test_runner.model)
checkpoint.save(os.path.join(self.model_dir, "ckpt"))
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
self.model_dir,
max_to_keep=None,
step_counter=test_runnable.global_step)
step_counter=test_runner.global_step)
test_controller = controller.Controller(
strategy=strategy,
eval_fn=test_runnable.evaluate,
global_step=test_runnable.global_step,
evaluator=test_runner,
global_step=test_runner.global_step,
checkpoint_manager=checkpoint_manager,
summary_dir=os.path.join(self.model_dir, "summaries/train"),
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
eval_steps=2,
eval_interval=5)
test_controller.evaluate()
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
test_controller.evaluate(steps=2)
# Only eval summaries are written
self.assertFalse(
......@@ -303,6 +344,207 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
check_eventfile_for_keyword(
"eval_loss", os.path.join(self.model_dir, "summaries/eval")))
# Tests continuous eval with timeout and timeout_fn.
done_file = os.path.join(self.model_dir, "summaries/eval/Done")
def timeout_fn():
with tf.io.gfile.GFile(done_file, "w") as f:
f.write("DONE")
return True
test_controller = controller.Controller(
evaluator=test_runner,
global_step=test_runner.global_step,
checkpoint_manager=checkpoint_manager,
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
test_controller.evaluate_continuously(
timeout=1, timeout_fn=timeout_fn, steps=2)
self.assertNotEmpty(tf.io.gfile.glob(done_file))
def test_no_eval_steps(self):
test_runner = TestRunner()
checkpoint = tf.train.Checkpoint(model=test_runner.model)
checkpoint.save(os.path.join(self.model_dir, "ckpt"))
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
self.model_dir,
max_to_keep=None,
step_counter=test_runner.global_step)
test_controller = controller.Controller(
evaluator=test_runner,
global_step=test_runner.global_step,
checkpoint_manager=checkpoint_manager)
test_controller.evaluate()
def test_already_trained_model(self):
test_runner = TestRunner()
test_runner.global_step.assign(10)
checkpoint = tf.train.Checkpoint(
model=test_runner.model, optimizer=test_runner.optimizer)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
self.model_dir,
max_to_keep=None,
step_counter=test_runner.global_step,
checkpoint_interval=10)
test_controller = controller.Controller(
trainer=test_runner,
global_step=test_runner.global_step,
steps_per_loop=2,
checkpoint_manager=checkpoint_manager)
# `global_step` is already `train_steps`.
test_controller.train(steps=10)
  def test_summaries_inside_train_fn(self):
    """Summaries written inside `train_step` reach the train summary dir."""
    test_runner = TestTrainerWithSummaries()
    checkpoint = tf.train.Checkpoint(
        model=test_runner.model, optimizer=test_runner.optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step)
    test_controller = controller.Controller(
        trainer=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=2,
        summary_dir=os.path.join(self.model_dir, "summaries/train"),
        summary_interval=2,
        checkpoint_manager=checkpoint_manager,
    )
    test_controller.train(steps=10)
    # No checkpoints are saved: the manager was created without a
    # `checkpoint_interval`.
    self.assertEmpty(tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt*")))
    # Only train summaries are written.
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
    self.assertTrue(
        check_eventfile_for_keyword(
            "loss", os.path.join(self.model_dir, "summaries/train")))
    self.assertFalse(
        tf.io.gfile.exists(os.path.join(self.model_dir, "summaries/eval")))
def test_train_and_evaluate_with_same_summary_dir(self):
test_runner = TestRunner()
checkpoint = tf.train.Checkpoint(
model=test_runner.model, optimizer=test_runner.optimizer)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
self.model_dir,
max_to_keep=None,
step_counter=test_runner.global_step)
test_controller = controller.Controller(
trainer=test_runner,
evaluator=test_runner,
global_step=test_runner.global_step,
steps_per_loop=2,
summary_dir=os.path.join(self.model_dir, "summaries"),
checkpoint_manager=checkpoint_manager,
eval_summary_dir=os.path.join(self.model_dir, "summaries"))
test_controller.train_and_evaluate(
train_steps=10, eval_steps=2, eval_interval=6)
# Loss and accuracy values should be written into summaries.
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries")))
self.assertTrue(
check_eventfile_for_keyword("loss",
os.path.join(self.model_dir, "summaries")))
self.assertTrue(
check_eventfile_for_keyword("eval_loss",
os.path.join(self.model_dir, "summaries")))
  def test_early_stop_on_eval_loss(self):
    """Shows how a Controller subclass can stop training on an eval metric."""
    test_runner = TestRunner()

    class EarlyStopController(controller.Controller):
      """A Controller subclass that supports early stopping."""

      def train_and_evaluate(self,
                             train_steps: int = None,
                             eval_steps: int = None,
                             eval_interval: int = None):
        while self.global_step.numpy() < train_steps:
          interval = min(train_steps - self.global_step.numpy(), eval_interval)
          num_steps = self.global_step.numpy() + interval
          self.train(steps=num_steps, checkpoint_at_completion=False)
          self.evaluate(steps=eval_steps)
          # Early stop condition; the metric comes from the closed-over runner.
          if test_runner.eval_loss.result() < 0.1:
            logging.info(
                "Training early stopped as eval_loss %s is less than 0.1",
                test_runner.eval_loss.result())
            return

    checkpoint = tf.train.Checkpoint(
        model=test_runner.model, optimizer=test_runner.optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step,
        checkpoint_interval=10)
    test_controller = EarlyStopController(
        trainer=test_runner,
        evaluator=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=2,
        checkpoint_manager=checkpoint_manager)
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=6, eval_interval=2)
    # The loss is expected to cross the 0.1 threshold before step 10, so
    # training must have stopped early.
    self.assertLess(test_runner.global_step, 10)
def test_evaluate_with_loss_outputs(self):
test_evaluator = TestEvaluator()
checkpoint = tf.train.Checkpoint(model=test_evaluator.model)
checkpoint.save(os.path.join(self.model_dir, "ckpt"))
checkpoint_manager = tf.train.CheckpointManager(
checkpoint, self.model_dir, max_to_keep=None)
test_controller = controller.Controller(
evaluator=test_evaluator,
global_step=tf.Variable(0, dtype=tf.int64),
checkpoint_manager=checkpoint_manager,
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
test_controller.evaluate(steps=5)
# Only eval summaries are written
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
self.assertTrue(
check_eventfile_for_keyword(
"eval_loss", os.path.join(self.model_dir, "summaries/eval")))
def test_train_and_evaluate_reset_datasets(self):
test_runner = TestRunner()
test_controller = controller.Controller(
trainer=test_runner,
evaluator=test_runner,
global_step=test_runner.global_step,
steps_per_loop=2)
test_controller.train_and_evaluate(
train_steps=10, eval_steps=2, eval_interval=6)
train_dataset = (
test_runner.strategy.experimental_distribute_datasets_from_function(
dataset_fn))
eval_dataset = (
test_runner.strategy.experimental_distribute_datasets_from_function(
dataset_fn))
test_runner.train_dataset = train_dataset
test_runner.eval_dataset = eval_dataset
test_controller.train_and_evaluate(
train_steps=10, eval_steps=2, eval_interval=6)
if __name__ == "__main__":
  # Discovers and runs the test cases above via the TensorFlow test runner.
  tf.test.main()
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -14,19 +15,12 @@
# ==============================================================================
"""An abstraction that users can easily handle their custom training loops."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import abc
import six
import tensorflow.compat.v2 as tf
from typing import Dict, Optional, Text
import tensorflow as tf
@six.add_metaclass(abc.ABCMeta)
class AbstractTrainable(tf.Module):
class AbstractTrainer(tf.Module, metaclass=abc.ABCMeta):
"""An abstract class defining the APIs required for training."""
@abc.abstractmethod
......@@ -50,14 +44,13 @@ class AbstractTrainable(tf.Module):
one update to model parameters, e.g. if training a GAN).
Returns:
The function may return a dictionary of `Tensors`, which will be
written to logs and as TensorBoard summaries.
The function may return a dictionary of `Tensors` or numpy arrays, which
will be written to logs and as TensorBoard summaries.
"""
pass
@six.add_metaclass(abc.ABCMeta)
class AbstractEvaluable(tf.Module):
class AbstractEvaluator(tf.Module, metaclass=abc.ABCMeta):
"""An abstract class defining the APIs required for evaluation."""
@abc.abstractmethod
......@@ -73,7 +66,7 @@ class AbstractEvaluable(tf.Module):
is `None`.
Returns:
The function may return a dictionary of `Tensors`, which will be
written to logs and as TensorBoard summaries.
The function may return a dictionary of `Tensors` or numpy arrays, which
will be written to logs and as TensorBoard summaries.
"""
pass
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -14,67 +15,101 @@
# ==============================================================================
"""An abstraction that users can easily handle their custom training loops."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import abc
import six
import tensorflow.compat.v2 as tf
from typing import Dict, Optional, Text
from official.staging.training import runnable
from official.staging.training import utils
from typing import Any, Dict, Optional, Text
import dataclasses
from orbit import runner
from orbit import utils
import tensorflow as tf
@dataclasses.dataclass(frozen=True)
class TrainerOverrides:
  """Advanced, rarely-needed knobs for Orbit trainers.

  Attributes:
    use_tf_while_loop: Whether to wrap the train step in a `tf.while_loop`.
    use_tf_function: Whether to compile training with `tf.function`. When
      False, training runs in pure eager mode.
    use_tpu_summary_optimization: Whether to enable the TPU summary
      performance optimization. Writing summaries with outside compilation
      inside the train step is slow on TPUs; when True, two `tf.function`s
      (two XLA programs) are built — one with summaries and one without — and
      the slower summary-writing program runs only when necessary.
  """
  use_tf_while_loop: bool = True
  use_tf_function: bool = True
  use_tpu_summary_optimization: bool = False
class StandardTrainer(runner.AbstractTrainer, metaclass=abc.ABCMeta):
"""Implements the standard functionality of AbstractTrainer APIs."""
def __init__(self,
train_dataset,
use_tf_while_loop=True,
use_tf_function=True,
use_tpu_summary_optimization=False):
"""Construct a `StandardTrainer` object.
@six.add_metaclass(abc.ABCMeta)
class StandardTrainable(runnable.AbstractTrainable):
"""Implements the standard functionality of AbstractTrainable APIs."""
def __init__(self, use_tf_while_loop=True, use_tf_function=True):
Args:
train_dataset: A tf.nest-compatible structure of tf.data.Dataset or
DistributedDataset.
use_tf_while_loop: A boolean indicates whether to wrap the train step with
a `tf.while_loop`.
use_tf_function: A boolean indicates whether a `tf.function` will be used.
If False, training will run on pure eager mode.
use_tpu_summary_optimization: A boolean indicates whether to enable the
performance optimization for summaries in TPUs. In TPUs, writing
summaries with outside compilation inside train step is slow. If True,
it creates two `tf.function` with two XLA programs: one with summaries
and one without, and run the program with summaries (slow one) only if
necessary.
"""
if use_tf_while_loop and not use_tf_function:
raise ValueError("`use_tf_while_loop=True` and `use_tf_function=False` "
"is not supported")
self.use_tf_while_loop = use_tf_while_loop
self.use_tf_function = use_tf_function
self.train_dataset = None
self.train_iter = None
self.train_loop_fn = None
@abc.abstractmethod
def build_train_dataset(self):
"""Builds the training datasets.
Returns:
A tf.nest-compatible structure of tf.data.Dataset or DistributedDataset.
"""
pass
if use_tpu_summary_optimization and not use_tf_while_loop:
raise ValueError("`use_tpu_summary_optimization=True` and "
"`use_tf_while_loop=False` is not supported")
self._use_tf_while_loop = use_tf_while_loop
self._use_tf_function = use_tf_function
self._train_dataset = train_dataset
self._train_iter = None
self._train_loop_fn = None
self._use_tpu_summary_optimization = use_tpu_summary_optimization
def train(self,
num_steps: Optional[tf.Tensor]) -> Optional[Dict[Text, tf.Tensor]]:
"""See base class."""
if self.train_dataset is None:
# Build train input dataset
self.train_dataset = self.build_train_dataset()
self.train_iter = tf.nest.map_structure(iter, self.train_dataset)
self.train_loop_begin()
if self._train_iter is None:
self._train_iter = tf.nest.map_structure(iter, self.train_dataset)
if self.train_loop_fn is None:
if self._train_loop_fn is None:
train_fn = self.train_step
if self.use_tf_while_loop:
self.train_loop_fn = utils.create_tf_while_loop_fn(train_fn)
if self._use_tf_while_loop:
self._train_loop_fn = utils.create_tf_while_loop_fn(train_fn)
if self._use_tpu_summary_optimization:
self._train_loop_fn = utils.train_function_with_summaries(
self._train_loop_fn)
else:
self._train_loop_fn = tf.function(self._train_loop_fn)
else:
if self.use_tf_function:
if self._use_tf_function:
train_fn = tf.function(train_fn)
self.train_loop_fn = utils.create_loop_fn(train_fn)
self._train_loop_fn = utils.create_loop_fn(train_fn)
self.train_loop_begin()
self.train_loop_fn(self.train_iter, num_steps)
self._train_loop_fn(self._train_iter, num_steps)
return self.train_loop_end()
def train_loop_begin(self):
"""Called once at the beginning of the training loop.
This method is called before dataset iterators creation.
This is a good place to reset metrics that accumulate values over multiple
steps of training.
"""
......@@ -107,54 +142,85 @@ class StandardTrainable(runnable.AbstractTrainable):
"""
pass
  @property
  def train_dataset(self):
    """The current training dataset (a tf.nest structure of datasets)."""
    return self._train_dataset
@six.add_metaclass(abc.ABCMeta)
class StandardEvaluable(runnable.AbstractEvaluable):
"""Implements the standard functionality of AbstractEvaluable APIs."""
@train_dataset.setter
def train_dataset(self, train_dataset):
"""Set a new train dataset and replace with the existing one.
def __init__(self, use_tf_function=True):
self.eval_use_tf_function = use_tf_function
self.eval_dataset = None
self.eval_loop_fn = None
Any unfinished work in the previous dataset will be discarded.
@abc.abstractmethod
def build_eval_dataset(self):
"""Builds the evaluation datasets.
Args:
train_dataset: A tf.nest-compatible structure of tf.data.Dataset or
DistributedDataset.
"""
self._train_dataset = train_dataset
self._train_iter = None
Returns:
A tf.nest-compatible structure of tf.data.Dataset or DistributedDataset.
@dataclasses.dataclass(frozen=True)
class EvaluatorOverrides:
  """Advanced overrides for Orbit evaluators.

  Attributes:
    use_tf_function: Whether evaluation runs inside a `tf.function`. When
      False, evaluation executes in pure eager mode.
  """

  use_tf_function: bool = True
class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta):
"""Implements the standard functionality of AbstractEvaluator APIs."""
def __init__(self, eval_dataset, use_tf_function=True):
"""Construct a `StandardEvaluator` object.
Args:
eval_dataset: A tf.nest-compatible structure of tf.data.Dataset or
DistributedDataset.
use_tf_function: A boolean indicates whether a `tf.function` will be used.
If False, evaluation will run on pure eager mode.
"""
pass
self._eval_use_tf_function = use_tf_function
self._eval_dataset = eval_dataset
self._eval_loop_fn = None
def evaluate(
self, num_steps: Optional[tf.Tensor]) -> Optional[Dict[Text, tf.Tensor]]:
"""See base class."""
if self.eval_dataset is None:
# Build train input dataset
self.eval_dataset = self.build_eval_dataset()
outputs = self.eval_begin() # pylint: disable=assignment-from-no-return
if self.eval_loop_fn is None:
eval_iter = tf.nest.map_structure(iter, self._eval_dataset)
if self._eval_loop_fn is None:
eval_fn = self.eval_step
if self.eval_use_tf_function:
if self._eval_use_tf_function:
eval_fn = tf.function(eval_fn)
self.eval_loop_fn = utils.create_loop_fn(eval_fn)
self._eval_loop_fn = utils.create_loop_fn(eval_fn)
eval_iter = tf.nest.map_structure(iter, self.eval_dataset)
outputs = self._eval_loop_fn(
eval_iter, num_steps, state=outputs, reduce_fn=self.eval_reduce)
if outputs is None:
return self.eval_end()
else:
return self.eval_end(outputs)
self.eval_begin()
self.eval_loop_fn(eval_iter, num_steps)
return self.eval_end()
def eval_begin(self):
def eval_begin(self) -> Any:
"""Called once at the beginning of the evaluation.
This method is called before dataset iterators creation.
This is a good place to reset metrics that accumulate values over the entire
evaluation.
Returns:
An output which is passed as `state` argument into `eval_reduce` function.
"""
pass
@abc.abstractmethod
def eval_step(self, iterator):
def eval_step(self, iterator) -> Any:
"""Implements one step of evaluation.
What a "step" consists of is up to the implementer. If using distribution
......@@ -165,17 +231,57 @@ class StandardEvaluable(runnable.AbstractEvaluable):
Args:
iterator: A tf.nest-compatible structure of tf.data Iterator or
DistributedIterator.
Returns:
An output which is passed as `step_outputs` argument into `eval_reduce`
function.
"""
pass
def eval_end(self) -> Optional[Dict[Text, tf.Tensor]]:
def eval_end(self, *args) -> Optional[Dict[Text, tf.Tensor]]:
"""Called at the end of the evaluation.
This is a good place to get metric results. The value returned from this
function will be returned as-is from the evaluate() method.
Args:
*args: the outputs from `eval_reduce` for the last eval step.
Returns:
The function may return a dictionary of `Tensors`, which will be
written to logs and as TensorBoard summaries.
"""
pass
def eval_reduce(self, state: Any = None, step_outputs: Any = None) -> Any:
  """Reduces (accumulates) evaluation outputs across steps.

  This is useful for passing states throughout evaluation. E.g. it can be used
  to maintain the output losses from all the evaluation steps, and compute the
  mean loss in the `eval_end` function.

  Args:
    state: A maintained state throughout the evaluation.
    step_outputs: Outputs from the current evaluation step.

  Returns:
    An output which is passed as the `state` argument into `eval_reduce` for
    the next step. After evaluation is finished, the output from the last step
    is passed into the `eval_end` function.
  """
  # Default implementation: no reduction. Subclasses override this to
  # accumulate per-step outputs (the base returns None implicitly).
  pass
@property
def eval_dataset(self):
  """Returns the eval_dataset instance."""
  return self._eval_dataset
@eval_dataset.setter
def eval_dataset(self, eval_dataset):
  """Sets a new eval dataset, replacing the existing one.

  Args:
    eval_dataset: A tf.nest-compatible structure of tf.data.Dataset or
      DistributedDataset.
  """
  self._eval_dataset = eval_dataset
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for orbit.standard_runner."""
# pylint: disable=g-bad-import-order
from orbit import standard_runner
import tensorflow as tf
def dataset_fn(input_context=None):
  """Builds an infinite dataset of (1, 1) float32 zero tensors."""
  del input_context  # Unused.

  def _zeros(_):
    return tf.zeros((1, 1), dtype=tf.float32)

  return (
      tf.data.Dataset.range(1)
      .repeat()
      .map(_zeros, num_parallel_calls=tf.data.experimental.AUTOTUNE))
class TestRunner(standard_runner.StandardTrainer,
                 standard_runner.StandardEvaluator):
  """A minimal trainer/evaluator used to exercise the standard runner APIs."""

  def __init__(self):
    self.strategy = tf.distribute.get_strategy()
    self.global_step = tf.Variable(
        0,
        trainable=False,
        dtype=tf.int64,
        name='global_step',
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
    # Datasets are supplied lazily in train_loop_begin/eval_begin.
    standard_runner.StandardTrainer.__init__(self, train_dataset=None)
    standard_runner.StandardEvaluator.__init__(self, eval_dataset=None)

  def train_loop_begin(self):
    self.train_dataset = (
        self.strategy.experimental_distribute_datasets_from_function(
            dataset_fn))

  def train_step(self, iterator):
    def _step(_):
      # Each replica bumps the step counter; aggregation keeps one increment.
      self.global_step.assign_add(1)

    self.strategy.run(_step, args=(next(iterator),))

  def train_loop_end(self):
    return self.global_step.numpy()

  def eval_begin(self):
    self.eval_dataset = (
        self.strategy.experimental_distribute_datasets_from_function(
            dataset_fn))

  def eval_step(self, iterator):
    def _step(_):
      self.global_step.assign_add(1)

    self.strategy.run(_step, args=(next(iterator),))

  def eval_end(self):
    return self.global_step.numpy()
class StandardRunnerTest(tf.test.TestCase):
  """Smoke tests for the TestRunner train/evaluate entry points."""

  def test_train(self):
    runner = TestRunner()
    num_steps = tf.convert_to_tensor(10, dtype=tf.int32)
    self.assertEqual(runner.train(num_steps), 10)

  def test_eval(self):
    runner = TestRunner()
    num_steps = tf.convert_to_tensor(10, dtype=tf.int32)
    self.assertEqual(runner.evaluate(num_steps), 10)
# Run the test suite when this module is executed directly.
if __name__ == '__main__':
  tf.test.main()
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -14,16 +15,13 @@
# ==============================================================================
"""Some layered modules/functions to help users writing custom training loop."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import abc
import contextlib
import functools
import inspect
import six
import tensorflow.compat.v2 as tf
import numpy as np
import tensorflow as tf
def create_loop_fn(step_fn):
......@@ -79,7 +77,6 @@ def create_tf_while_loop_fn(step_fn):
A callable defined as the `loop_fn` defination below.
"""
@tf.function
def loop_fn(iterator, num_steps):
"""A loop function with multiple steps.
......@@ -130,10 +127,7 @@ def make_distributed_dataset(strategy, dataset_or_fn, *args, **kwargs):
# names, pass `ctx` as the value of `input_context` when calling
# `dataset_or_fn`. Otherwise `ctx` will not be used when calling
# `dataset_or_fn`.
if six.PY3:
argspec = inspect.getfullargspec(dataset_or_fn)
else:
argspec = inspect.getargspec(dataset_or_fn)
argspec = inspect.getfullargspec(dataset_or_fn)
args_names = argspec.args
if "input_context" in args_names:
......@@ -144,96 +138,62 @@ def make_distributed_dataset(strategy, dataset_or_fn, *args, **kwargs):
return strategy.experimental_distribute_datasets_from_function(dataset_fn)
class SummaryManager(object):
class SummaryManager:
"""A class manages writing summaries."""
def __init__(self,
summary_writer,
summary_fn,
global_step=None,
summary_interval=None):
def __init__(self, summary_dir, summary_fn, global_step=None):
"""Construct a summary manager object.
Args:
summary_writer: A `tf.summary.SummaryWriter` instance for writing
summaries.
summary_dir: the directory to write summaries.
summary_fn: A callable defined as `def summary_fn(name, tensor,
step=None)`, which describes the summary operation.
global_step: A `tf.Variable` instance for checking the current global step
value, in case users want to save summaries every N steps.
summary_interval: An integer, indicates the minimum step interval between
two summaries.
global_step: A `tf.Variable` instance for the global step.
"""
if summary_writer is not None:
self._summary_writer = summary_writer
self._enabled = True
else:
self._summary_writer = tf.summary.create_noop_writer()
self._enabled = False
self._enabled = (summary_dir is not None)
self._summary_dir = summary_dir
self._summary_fn = summary_fn
self._summary_writer = None
if global_step is None:
self._global_step = tf.summary.experimental.get_step()
else:
self._global_step = global_step
if summary_interval is not None:
if self._global_step is None:
raise ValueError("`summary_interval` is not None, but no `global_step` "
"can be obtained ")
self._last_summary_step = self._global_step.numpy()
self._summary_interval = summary_interval
@property
def summary_interval(self):
return self._summary_interval
@property
def summary_writer(self):
"""Returns the underlying summary writer."""
if self._summary_writer is not None:
return self._summary_writer
if self._enabled:
self._summary_writer = tf.summary.create_file_writer(self._summary_dir)
else:
self._summary_writer = tf.summary.create_noop_writer()
return self._summary_writer
def flush(self):
"""Flush the underlying summary writer."""
if self._enabled:
tf.summary.flush(self._summary_writer)
tf.summary.flush(self.summary_writer)
def write_summaries(self, items, always_write=True):
def write_summaries(self, items):
"""Write a bulk of summaries.
Args:
items: a dictionary of `Tensors` for writing summaries.
always_write: An optional boolean. If `True`, the manager will always
write summaries unless the summaries have been written for the same
step. Otherwise the manager will only write the summaries if the
interval between summaries are larger than `summary_interval`.
Returns:
A boolean indicates whether the summaries are written or not.
"""
# TODO(rxsang): Support writing summaries with nested structure, so users
# can split the summaries into different directories for nicer visualization
# in Tensorboard, like train and eval metrics.
if not self._enabled:
return False
if self._summary_interval is not None:
current_step = self._global_step.numpy()
if current_step == self._last_summary_step:
return False
if not always_write and current_step < (self._last_summary_step +
self._summary_interval):
return False
self._last_summary_step = current_step
return
with self._summary_writer.as_default():
with self.summary_writer.as_default():
for name, tensor in items.items():
self._summary_fn(name, tensor, step=self._global_step)
return True
@six.add_metaclass(abc.ABCMeta)
class Trigger(object):
class Trigger(metaclass=abc.ABCMeta):
"""An abstract class representing a "trigger" for some event."""
@abc.abstractmethod
......@@ -294,7 +254,7 @@ class IntervalTrigger(Trigger):
self._last_trigger_value = 0
class EpochHelper(object):
class EpochHelper:
"""A Helper class to handle epochs in Customized Training Loop."""
def __init__(self, epoch_steps, global_step):
......@@ -340,3 +300,86 @@ class EpochHelper(object):
@property
def current_epoch(self):
return self._current_epoch
@contextlib.contextmanager
def _soft_device_placement():
  """Context manager enabling soft device placement (summaries on CPU)."""
  previous = tf.config.get_soft_device_placement()
  try:
    tf.config.set_soft_device_placement(True)
    yield
  finally:
    # Restore whatever setting was active before entering the context.
    tf.config.set_soft_device_placement(previous)
def train_function_with_summaries(*args, **kwargs):
  """Utility function to support TPU summaries via multiple `tf.function`s.

  This permits interleaving summaries inside TPU-compatible code, but without
  any performance impact on steps that do not write summaries.

  Usage is as a decorator, similar to `tf.function`, and any `tf.function`
  arguments will be passed through if supplied:

      @trainer.train_function_with_summaries
      def train(self, num_steps):
        ...

  The decorated function is assumed to be a loop method accepting a `num_steps`
  parameter, as for instance would be called within the `Controller`'s outer
  train loop. The implementation here assumes that `summary_frequency` is
  divisible by `steps_per_loop`. The decorated method should accept two
  arguments, `self` and `num_steps`.

  Two `tf.function` versions of `train_fn` are created: one inside a summary
  writer scope with soft device placement enabled (used on steps that require
  summary writing), and one with no summary writer present and soft device
  placement disabled (used on all other steps).

  Args:
    *args: Arguments to pass through to `tf.function`.
    **kwargs: Keyword arguments to pass through to `tf.function`.

  Returns:
    If the first argument is a callable, returns the decorated callable.
    Otherwise, returns a decorator.
  """

  def decorator(train_fn):
    # TODO(dhr): Validate the signature of train_fn?

    # Two separately traced copies of the same function: one traced where
    # summaries may be recorded, one traced with recording disabled.
    train_fn_with_summaries = tf.function(train_fn, *args, **kwargs)
    train_fn_without_summaries = tf.function(train_fn, *args, **kwargs)

    @functools.wraps(train_fn)
    def wrapper(self, num_steps):
      if tf.summary.should_record_summaries():
        # Run exactly one step through the summary-enabled function (with
        # soft device placement so summary ops can fall back to CPU), then
        # run the remaining steps without summaries.
        with _soft_device_placement():
          output = train_fn_with_summaries(self, tf.constant(1))
          num_steps -= 1
      if num_steps >= 1:
        with tf.summary.record_if(False):
          output = train_fn_without_summaries(self, num_steps)
      return output

    return wrapper

  # Support bare-decorator usage (`@train_function_with_summaries`) in
  # addition to parameterized usage (`@train_function_with_summaries(...)`).
  if args and callable(args[0]):
    train_fn, args = args[0], args[1:]
    return decorator(train_fn)
  return decorator
def get_value(x) -> np.ndarray:
  """Returns the concrete value of a variable/tensor.

  Args:
    x: input variable.

  Returns:
    A Numpy array.
  """
  if tf.is_tensor(x):
    return x.numpy()
  return x
# Attention-based Extraction of Structured Information from Street View Imagery
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/attention-based-extraction-of-structured/optical-character-recognition-on-fsns-test)](https://paperswithcode.com/sota/optical-character-recognition-on-fsns-test?p=attention-based-extraction-of-structured)
[![Paper](http://img.shields.io/badge/paper-arXiv.1704.03549-B3181B.svg)](https://arxiv.org/abs/1704.03549)
......@@ -7,14 +7,20 @@
*A TensorFlow model for real-world image text extraction problems.*
This folder contains the code needed to train a new Attention OCR model on the
[FSNS dataset][FSNS] to transcribe street names in France. You can also train
the code on your own data.
More details can be found in our paper:
["Attention-based Extraction of Structured Information from Street View
Imagery"](https://arxiv.org/abs/1704.03549)
## Description
* Paper presents a model based on ConvNets, RNN's and a novel attention mechanism.
Achieves **84.2%** on FSNS beating the previous benchmark (**72.46%**). Also studies
the speed/accuracy tradeoff that results from using CNN feature extractors of
different depths.
## Contacts
Authors
......@@ -22,7 +28,18 @@ Authors
* Zbigniew Wojna (zbigniewwojna@gmail.com)
* Alexander Gorban (gorban@google.com)
Maintainer

* Xavier Gibert ([@xavigibert](https://github.com/xavigibert))
## Table of Contents
* [Requirements](https://github.com/tensorflow/models/blob/master/research/attention_ocr/README.md#requirements)
* [Dataset](https://github.com/tensorflow/models/blob/master/research/attention_ocr/README.md#dataset)
* [How to use this code](https://github.com/tensorflow/models/blob/master/research/attention_ocr/README.md#how-to-use-this-code)
* [Using your own image data](https://github.com/tensorflow/models/blob/master/research/attention_ocr/README.md#using-your-own-image-data)
* [How to use a pre-trained model](https://github.com/tensorflow/models/blob/master/research/attention_ocr/README.md#how-to-use-a-pre-trained-model)
* [Disclaimer](https://github.com/tensorflow/models/blob/master/research/attention_ocr/README.md#disclaimer)
## Requirements
......@@ -49,6 +66,42 @@ cd ..
[TF]: https://www.tensorflow.org/install/
[FSNS]: https://github.com/tensorflow/models/tree/master/research/street
## Dataset
The French Street Name Signs (FSNS) dataset is split into subsets,
each of which is composed of multiple files. Note that these datasets
are very large. The approximate sizes are:
* Train: 512 files of 300MB each.
* Validation: 64 files of 40MB each.
* Test: 64 files of 50MB each.
* The datasets download includes a directory `testdata` that contains
some small datasets that are big enough to test that models can
actually learn something.
* Total: around 158GB
The download paths are in the following list:
```
https://download.tensorflow.org/data/fsns-20160927/charset_size=134.txt
https://download.tensorflow.org/data/fsns-20160927/test/test-00000-of-00064
...
https://download.tensorflow.org/data/fsns-20160927/test/test-00063-of-00064
https://download.tensorflow.org/data/fsns-20160927/testdata/arial-32-00000-of-00001
https://download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001
https://download.tensorflow.org/data/fsns-20160927/testdata/mnist-sample-00000-of-00001
https://download.tensorflow.org/data/fsns-20160927/testdata/numbers-16-00000-of-00001
https://download.tensorflow.org/data/fsns-20160927/train/train-00000-of-00512
...
https://download.tensorflow.org/data/fsns-20160927/train/train-00511-of-00512
https://download.tensorflow.org/data/fsns-20160927/validation/validation-00000-of-00064
...
https://download.tensorflow.org/data/fsns-20160927/validation/validation-00063-of-00064
```
All URLs are stored in the [research/street](https://github.com/tensorflow/models/tree/master/research/street)
repository in the text file `python/fsns_urls.txt`.
## How to use this code
To run all unit tests:
......@@ -80,7 +133,7 @@ tar xf attention_ocr_2017_08_09.tar.gz
python train.py --checkpoint=model.ckpt-399731
```
## Using your own image data
You need to define a new dataset. There are two options:
......@@ -166,6 +219,14 @@ implement one in Python or C++.
The recommended way is to use the [Serving infrastructure][serving].
To export to SavedModel format:
```
python model_export.py \
--checkpoint=model.ckpt-399731 \
--export_dir=/tmp/attention_ocr_export
```
Alternatively you can:
1. define a placeholder for images (or use directly an numpy array)
2. [create a graph ](https://github.com/tensorflow/models/blob/master/research/attention_ocr/python/eval.py#L60)
......@@ -188,7 +249,7 @@ other than a one time experiment please use the [TensorFlow Serving][serving].
[1]: https://github.com/tensorflow/tensorflow/blob/aaf7adc/tensorflow/contrib/rnn/python/tools/checkpoint_convert.py
[2]: https://www.tensorflow.org/api_docs/python/tf/contrib/framework/assign_from_checkpoint_fn
[serving]: https://tensorflow.github.io/serving/serving_basic
[serving]: https://www.tensorflow.org/tfx/serving/serving_basic
## Disclaimer
......
......@@ -14,10 +14,10 @@
# ==============================================================================
"""Define flags are common for both train.py and eval.py scripts."""
import logging
import sys
from tensorflow.python.platform import flags
import logging
import datasets
import model
......@@ -35,9 +35,17 @@ logging.basicConfig(
datefmt='%Y-%m-%d %H:%M:%S')
_common_flags_defined = False
def define():
"""Define common flags."""
# yapf: disable
# common_flags.define() may be called multiple times in unit tests.
global _common_flags_defined
if _common_flags_defined:
return
_common_flags_defined = True
flags.DEFINE_integer('batch_size', 32,
'Batch size.')
......@@ -74,7 +82,7 @@ def define():
'the optimizer to use')
flags.DEFINE_float('momentum', 0.9,
'momentum value for the momentum optimizer if used')
'momentum value for the momentum optimizer if used')
flags.DEFINE_bool('use_augment_input', True,
'If True will use image augmentation')
......
......@@ -56,14 +56,14 @@ def augment_image(image):
Returns:
Distorted Tensor image of the same shape.
"""
with tf.variable_scope('AugmentImage'):
with tf.compat.v1.variable_scope('AugmentImage'):
height = image.get_shape().dims[0].value
width = image.get_shape().dims[1].value
# Random crop cut from the street sign image, resized to the same size.
# Assures that the crop is covers at least 0.8 area of the input image.
bbox_begin, bbox_size, _ = tf.image.sample_distorted_bounding_box(
tf.shape(image),
image_size=tf.shape(input=image),
bounding_boxes=tf.zeros([0, 0, 4]),
min_object_covered=0.8,
aspect_ratio_range=[0.8, 1.2],
......@@ -74,7 +74,7 @@ def augment_image(image):
# Randomly chooses one of the 4 interpolation methods
distorted_image = inception_preprocessing.apply_with_random_selector(
distorted_image,
lambda x, method: tf.image.resize_images(x, [height, width], method),
lambda x, method: tf.image.resize(x, [height, width], method),
num_cases=4)
distorted_image.set_shape([height, width, 3])
......@@ -99,9 +99,10 @@ def central_crop(image, crop_size):
Returns:
A tensor of shape [crop_height, crop_width, channels].
"""
with tf.variable_scope('CentralCrop'):
with tf.compat.v1.variable_scope('CentralCrop'):
target_width, target_height = crop_size
image_height, image_width = tf.shape(image)[0], tf.shape(image)[1]
image_height, image_width = tf.shape(
input=image)[0], tf.shape(input=image)[1]
assert_op1 = tf.Assert(
tf.greater_equal(image_height, target_height),
['image_height < target_height', image_height, target_height])
......@@ -129,7 +130,7 @@ def preprocess_image(image, augment=False, central_crop_size=None,
A float32 tensor of shape [H x W x 3] with RGB values in the required
range.
"""
with tf.variable_scope('PreprocessImage'):
with tf.compat.v1.variable_scope('PreprocessImage'):
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
if augment or central_crop_size:
if num_towers == 1:
......@@ -144,9 +145,6 @@ def preprocess_image(image, augment=False, central_crop_size=None,
images = [augment_image(img) for img in images]
image = tf.concat(images, 1)
image = tf.subtract(image, 0.5)
image = tf.multiply(image, 2.5)
return image
......@@ -185,7 +183,7 @@ def get_data(dataset,
image_orig, augment, central_crop_size, num_towers=dataset.num_of_views)
label_one_hot = slim.one_hot_encoding(label, dataset.num_char_classes)
images, images_orig, labels, labels_one_hot = (tf.train.shuffle_batch(
images, images_orig, labels, labels_one_hot = (tf.compat.v1.train.shuffle_batch(
[image, image_orig, label, label_one_hot],
batch_size=batch_size,
num_threads=shuffle_config.num_batching_threads,
......
......@@ -72,7 +72,7 @@ def read_charset(filename, null_character=u'\u2591'):
"""
pattern = re.compile(r'(\d+)\t(.+)')
charset = {}
with tf.gfile.GFile(filename) as f:
with tf.io.gfile.GFile(filename) as f:
for i, line in enumerate(f):
m = pattern.match(line)
if m is None:
......@@ -96,9 +96,9 @@ class _NumOfViewsHandler(slim.tfexample_decoder.ItemHandler):
self._num_of_views = num_of_views
def tensors_to_item(self, keys_to_tensors):
return tf.to_int64(
return tf.cast(
self._num_of_views * keys_to_tensors[self._original_width_key] /
keys_to_tensors[self._width_key])
keys_to_tensors[self._width_key], dtype=tf.int64)
def get_split(split_name, dataset_dir=None, config=None):
......@@ -133,19 +133,19 @@ def get_split(split_name, dataset_dir=None, config=None):
zero = tf.zeros([1], dtype=tf.int64)
keys_to_features = {
'image/encoded':
tf.FixedLenFeature((), tf.string, default_value=''),
tf.io.FixedLenFeature((), tf.string, default_value=''),
'image/format':
tf.FixedLenFeature((), tf.string, default_value='png'),
tf.io.FixedLenFeature((), tf.string, default_value='png'),
'image/width':
tf.FixedLenFeature([1], tf.int64, default_value=zero),
tf.io.FixedLenFeature([1], tf.int64, default_value=zero),
'image/orig_width':
tf.FixedLenFeature([1], tf.int64, default_value=zero),
tf.io.FixedLenFeature([1], tf.int64, default_value=zero),
'image/class':
tf.FixedLenFeature([config['max_sequence_length']], tf.int64),
tf.io.FixedLenFeature([config['max_sequence_length']], tf.int64),
'image/unpadded_class':
tf.VarLenFeature(tf.int64),
tf.io.VarLenFeature(tf.int64),
'image/text':
tf.FixedLenFeature([1], tf.string, default_value=''),
tf.io.FixedLenFeature([1], tf.string, default_value=''),
}
items_to_handlers = {
'image':
......@@ -171,12 +171,14 @@ def get_split(split_name, dataset_dir=None, config=None):
config['splits'][split_name]['pattern'])
return slim.dataset.Dataset(
data_sources=file_pattern,
reader=tf.TFRecordReader,
reader=tf.compat.v1.TFRecordReader,
decoder=decoder,
num_samples=config['splits'][split_name]['size'],
items_to_descriptions=config['items_to_descriptions'],
# additional parameters for convenience.
charset=charset,
charset_file=charset_file,
image_shape=config['image_shape'],
num_char_classes=len(charset),
num_of_views=config['num_of_views'],
max_sequence_length=config['max_sequence_length'],
......
......@@ -91,7 +91,7 @@ class FsnsTest(tf.test.TestCase):
image_tf, label_tf = provider.get(['image', 'label'])
with self.test_session() as sess:
sess.run(tf.global_variables_initializer())
sess.run(tf.compat.v1.global_variables_initializer())
with slim.queues.QueueRunners(sess):
image_np, label_np = sess.run([image_tf, label_tf])
......
......@@ -10,7 +10,8 @@ KEEP_NUM_RECORDS = 5
print('Downloading %s ...' % URL)
urllib.request.urlretrieve(URL, DST_ORIG)
print('Writing %d records from %s to %s ...' % (KEEP_NUM_RECORDS, DST_ORIG, DST))
print('Writing %d records from %s to %s ...' %
(KEEP_NUM_RECORDS, DST_ORIG, DST))
with tf.io.TFRecordWriter(DST) as writer:
for raw_record in itertools.islice(tf.python_io.tf_record_iterator(DST_ORIG), KEEP_NUM_RECORDS):
for raw_record in itertools.islice(tf.compat.v1.python_io.tf_record_iterator(DST_ORIG), KEEP_NUM_RECORDS):
writer.write(raw_record)
......@@ -49,7 +49,7 @@ def load_images(file_pattern, batch_size, dataset_name):
for i in range(batch_size):
path = file_pattern % i
print("Reading %s" % path)
pil_image = PIL.Image.open(tf.gfile.GFile(path, 'rb'))
pil_image = PIL.Image.open(tf.io.gfile.GFile(path, 'rb'))
images_actual_data[i, ...] = np.asarray(pil_image)
return images_actual_data
......@@ -58,12 +58,13 @@ def create_model(batch_size, dataset_name):
width, height = get_dataset_image_size(dataset_name)
dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
model = common_flags.create_model(
num_char_classes=dataset.num_char_classes,
seq_length=dataset.max_sequence_length,
num_views=dataset.num_of_views,
null_code=dataset.null_code,
charset=dataset.charset)
raw_images = tf.placeholder(tf.uint8, shape=[batch_size, height, width, 3])
num_char_classes=dataset.num_char_classes,
seq_length=dataset.max_sequence_length,
num_views=dataset.num_of_views,
null_code=dataset.null_code,
charset=dataset.charset)
raw_images = tf.compat.v1.placeholder(
tf.uint8, shape=[batch_size, height, width, 3])
images = tf.map_fn(data_provider.preprocess_image, raw_images,
dtype=tf.float32)
endpoints = model.create_base(images, labels_one_hot=None)
......@@ -76,9 +77,9 @@ def run(checkpoint, batch_size, dataset_name, image_path_pattern):
images_data = load_images(image_path_pattern, batch_size,
dataset_name)
session_creator = monitored_session.ChiefSessionCreator(
checkpoint_filename_with_path=checkpoint)
checkpoint_filename_with_path=checkpoint)
with monitored_session.MonitoredSession(
session_creator=session_creator) as sess:
session_creator=session_creator) as sess:
predictions = sess.run(endpoints.predicted_text,
feed_dict={images_placeholder: images_data})
return [pr_bytes.decode('utf-8') for pr_bytes in predictions.tolist()]
......@@ -87,10 +88,10 @@ def run(checkpoint, batch_size, dataset_name, image_path_pattern):
def main(_):
print("Predicted strings:")
predictions = run(FLAGS.checkpoint, FLAGS.batch_size, FLAGS.dataset_name,
FLAGS.image_path_pattern)
FLAGS.image_path_pattern)
for line in predictions:
print(line)
if __name__ == '__main__':
tf.app.run()
tf.compat.v1.app.run()
......@@ -14,12 +14,13 @@ class DemoInferenceTest(tf.test.TestCase):
super(DemoInferenceTest, self).setUp()
for suffix in ['.meta', '.index', '.data-00000-of-00001']:
filename = _CHECKPOINT + suffix
self.assertTrue(tf.gfile.Exists(filename),
self.assertTrue(tf.io.gfile.exists(filename),
msg='Missing checkpoint file %s. '
'Please download and extract it from %s' %
(filename, _CHECKPOINT_URL))
self._batch_size = 32
tf.flags.FLAGS.dataset_dir = os.path.join(os.path.dirname(__file__), 'datasets/testdata/fsns')
tf.flags.FLAGS.dataset_dir = os.path.join(
os.path.dirname(__file__), 'datasets/testdata/fsns')
def test_moving_variables_properly_loaded_from_a_checkpoint(self):
batch_size = 32
......@@ -30,15 +31,15 @@ class DemoInferenceTest(tf.test.TestCase):
images_data = demo_inference.load_images(image_path_pattern, batch_size,
dataset_name)
tensor_name = 'AttentionOcr_v1/conv_tower_fn/INCE/InceptionV3/Conv2d_2a_3x3/BatchNorm/moving_mean'
moving_mean_tf = tf.get_default_graph().get_tensor_by_name(
tensor_name + ':0')
reader = tf.train.NewCheckpointReader(_CHECKPOINT)
moving_mean_tf = tf.compat.v1.get_default_graph().get_tensor_by_name(
tensor_name + ':0')
reader = tf.compat.v1.train.NewCheckpointReader(_CHECKPOINT)
moving_mean_expected = reader.get_tensor(tensor_name)
session_creator = monitored_session.ChiefSessionCreator(
checkpoint_filename_with_path=_CHECKPOINT)
checkpoint_filename_with_path=_CHECKPOINT)
with monitored_session.MonitoredSession(
session_creator=session_creator) as sess:
session_creator=session_creator) as sess:
moving_mean_np = sess.run(moving_mean_tf,
feed_dict={images_placeholder: images_data})
......@@ -50,38 +51,38 @@ class DemoInferenceTest(tf.test.TestCase):
'fsns',
image_path_pattern)
self.assertEqual([
u'Boulevard de Lunel░░░░░░░░░░░░░░░░░░░',
'Rue de Provence░░░░░░░░░░░░░░░░░░░░░░',
'Rue de Port Maria░░░░░░░░░░░░░░░░░░░░',
'Avenue Charles Gounod░░░░░░░░░░░░░░░░',
'Rue de l‘Aurore░░░░░░░░░░░░░░░░░░░░░░',
'Rue de Beuzeville░░░░░░░░░░░░░░░░░░░░',
'Rue d‘Orbey░░░░░░░░░░░░░░░░░░░░░░░░░░',
'Rue Victor Schoulcher░░░░░░░░░░░░░░░░',
'Rue de la Gare░░░░░░░░░░░░░░░░░░░░░░░',
'Rue des Tulipes░░░░░░░░░░░░░░░░░░░░░░',
'Rue André Maginot░░░░░░░░░░░░░░░░░░░░',
'Route de Pringy░░░░░░░░░░░░░░░░░░░░░░',
'Rue des Landelles░░░░░░░░░░░░░░░░░░░░',
'Rue des Ilettes░░░░░░░░░░░░░░░░░░░░░░',
'Avenue de Maurin░░░░░░░░░░░░░░░░░░░░░',
'Rue Théresa░░░░░░░░░░░░░░░░░░░░░░░░░░', # GT='Rue Thérésa'
'Route de la Balme░░░░░░░░░░░░░░░░░░░░',
'Rue Hélène Roederer░░░░░░░░░░░░░░░░░░',
'Rue Emile Bernard░░░░░░░░░░░░░░░░░░░░',
'Place de la Mairie░░░░░░░░░░░░░░░░░░░',
'Rue des Perrots░░░░░░░░░░░░░░░░░░░░░░',
'Rue de la Libération░░░░░░░░░░░░░░░░░',
'Impasse du Capcir░░░░░░░░░░░░░░░░░░░░',
'Avenue de la Grand Mare░░░░░░░░░░░░░░',
'Rue Pierre Brossolette░░░░░░░░░░░░░░░',
'Rue de Provence░░░░░░░░░░░░░░░░░░░░░░',
'Rue du Docteur Mourre░░░░░░░░░░░░░░░░',
'Rue d‘Ortheuil░░░░░░░░░░░░░░░░░░░░░░░',
'Rue des Sarments░░░░░░░░░░░░░░░░░░░░░',
'Rue du Centre░░░░░░░░░░░░░░░░░░░░░░░░',
'Impasse Pierre Mourgues░░░░░░░░░░░░░░',
'Rue Marcel Dassault░░░░░░░░░░░░░░░░░░'
u'Boulevard de Lunel░░░░░░░░░░░░░░░░░░░',
'Rue de Provence░░░░░░░░░░░░░░░░░░░░░░',
'Rue de Port Maria░░░░░░░░░░░░░░░░░░░░',
'Avenue Charles Gounod░░░░░░░░░░░░░░░░',
'Rue de l‘Aurore░░░░░░░░░░░░░░░░░░░░░░',
'Rue de Beuzeville░░░░░░░░░░░░░░░░░░░░',
'Rue d‘Orbey░░░░░░░░░░░░░░░░░░░░░░░░░░',
'Rue Victor Schoulcher░░░░░░░░░░░░░░░░',
'Rue de la Gare░░░░░░░░░░░░░░░░░░░░░░░',
'Rue des Tulipes░░░░░░░░░░░░░░░░░░░░░░',
'Rue André Maginot░░░░░░░░░░░░░░░░░░░░',
'Route de Pringy░░░░░░░░░░░░░░░░░░░░░░',
'Rue des Landelles░░░░░░░░░░░░░░░░░░░░',
'Rue des Ilettes░░░░░░░░░░░░░░░░░░░░░░',
'Avenue de Maurin░░░░░░░░░░░░░░░░░░░░░',
'Rue Théresa░░░░░░░░░░░░░░░░░░░░░░░░░░', # GT='Rue Thérésa'
'Route de la Balme░░░░░░░░░░░░░░░░░░░░',
'Rue Hélène Roederer░░░░░░░░░░░░░░░░░░',
'Rue Emile Bernard░░░░░░░░░░░░░░░░░░░░',
'Place de la Mairie░░░░░░░░░░░░░░░░░░░',
'Rue des Perrots░░░░░░░░░░░░░░░░░░░░░░',
'Rue de la Libération░░░░░░░░░░░░░░░░░',
'Impasse du Capcir░░░░░░░░░░░░░░░░░░░░',
'Avenue de la Grand Mare░░░░░░░░░░░░░░',
'Rue Pierre Brossolette░░░░░░░░░░░░░░░',
'Rue de Provence░░░░░░░░░░░░░░░░░░░░░░',
'Rue du Docteur Mourre░░░░░░░░░░░░░░░░',
'Rue d‘Ortheuil░░░░░░░░░░░░░░░░░░░░░░░',
'Rue des Sarments░░░░░░░░░░░░░░░░░░░░░',
'Rue du Centre░░░░░░░░░░░░░░░░░░░░░░░░',
'Impasse Pierre Mourgues░░░░░░░░░░░░░░',
'Rue Marcel Dassault░░░░░░░░░░░░░░░░░░'
], predictions)
......
......@@ -45,8 +45,8 @@ flags.DEFINE_integer('number_of_steps', None,
def main(_):
if not tf.gfile.Exists(FLAGS.eval_log_dir):
tf.gfile.MakeDirs(FLAGS.eval_log_dir)
if not tf.io.gfile.exists(FLAGS.eval_log_dir):
tf.io.gfile.makedirs(FLAGS.eval_log_dir)
dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
model = common_flags.create_model(dataset.num_char_classes,
......@@ -62,7 +62,7 @@ def main(_):
eval_ops = model.create_summaries(
data, endpoints, dataset.charset, is_training=False)
slim.get_or_create_global_step()
session_config = tf.ConfigProto(device_count={"GPU": 0})
session_config = tf.compat.v1.ConfigProto(device_count={"GPU": 0})
slim.evaluation.evaluation_loop(
master=FLAGS.master,
checkpoint_dir=FLAGS.train_log_dir,
......
......@@ -38,7 +38,7 @@ def apply_with_random_selector(x, func, num_cases):
The result of func(x, sel), where func receives the value of the
selector as a python integer, but sel is sampled dynamically.
"""
sel = tf.random_uniform([], maxval=num_cases, dtype=tf.int32)
sel = tf.random.uniform([], maxval=num_cases, dtype=tf.int32)
# Pass the real x only to one of the func calls.
return control_flow_ops.merge([
func(control_flow_ops.switch(x, tf.equal(sel, case))[1], case)
......@@ -64,7 +64,7 @@ def distort_color(image, color_ordering=0, fast_mode=True, scope=None):
Raises:
ValueError: if color_ordering not in [0, 3]
"""
with tf.name_scope(scope, 'distort_color', [image]):
with tf.compat.v1.name_scope(scope, 'distort_color', [image]):
if fast_mode:
if color_ordering == 0:
image = tf.image.random_brightness(image, max_delta=32. / 255.)
......@@ -131,7 +131,7 @@ def distorted_bounding_box_crop(image,
Returns:
A tuple, a 3-D Tensor cropped_image and the distorted bbox
"""
with tf.name_scope(scope, 'distorted_bounding_box_crop', [image, bbox]):
with tf.compat.v1.name_scope(scope, 'distorted_bounding_box_crop', [image, bbox]):
# Each bounding box has shape [1, num_boxes, box coords] and
# the coordinates are ordered [ymin, xmin, ymax, xmax].
......@@ -143,7 +143,7 @@ def distorted_bounding_box_crop(image,
# bounding box. If no box is supplied, then we assume the bounding box is
# the entire image.
sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
tf.shape(image),
image_size=tf.shape(input=image),
bounding_boxes=bbox,
min_object_covered=min_object_covered,
aspect_ratio_range=aspect_ratio_range,
......@@ -188,7 +188,7 @@ def preprocess_for_train(image,
Returns:
3-D float Tensor of distorted image used for training with range [-1, 1].
"""
with tf.name_scope(scope, 'distort_image', [image, height, width, bbox]):
with tf.compat.v1.name_scope(scope, 'distort_image', [image, height, width, bbox]):
if bbox is None:
bbox = tf.constant(
[0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
......@@ -198,7 +198,7 @@ def preprocess_for_train(image,
# the coordinates are ordered [ymin, xmin, ymax, xmax].
image_with_box = tf.image.draw_bounding_boxes(
tf.expand_dims(image, 0), bbox)
tf.summary.image('image_with_bounding_boxes', image_with_box)
tf.compat.v1.summary.image('image_with_bounding_boxes', image_with_box)
distorted_image, distorted_bbox = distorted_bounding_box_crop(image, bbox)
# Restore the shape since the dynamic slice based upon the bbox_size loses
......@@ -206,8 +206,8 @@ def preprocess_for_train(image,
distorted_image.set_shape([None, None, 3])
image_with_distorted_box = tf.image.draw_bounding_boxes(
tf.expand_dims(image, 0), distorted_bbox)
tf.summary.image('images_with_distorted_bounding_box',
image_with_distorted_box)
tf.compat.v1.summary.image('images_with_distorted_bounding_box',
image_with_distorted_box)
# This resizing operation may distort the images because the aspect
# ratio is not respected. We select a resize method in a round robin
......@@ -218,11 +218,11 @@ def preprocess_for_train(image,
num_resize_cases = 1 if fast_mode else 4
distorted_image = apply_with_random_selector(
distorted_image,
lambda x, method: tf.image.resize_images(x, [height, width], method=method),
lambda x, method: tf.image.resize(x, [height, width], method=method),
num_cases=num_resize_cases)
tf.summary.image('cropped_resized_image',
tf.expand_dims(distorted_image, 0))
tf.compat.v1.summary.image('cropped_resized_image',
tf.expand_dims(distorted_image, 0))
# Randomly flip the image horizontally.
distorted_image = tf.image.random_flip_left_right(distorted_image)
......@@ -233,8 +233,8 @@ def preprocess_for_train(image,
lambda x, ordering: distort_color(x, ordering, fast_mode),
num_cases=4)
tf.summary.image('final_distorted_image',
tf.expand_dims(distorted_image, 0))
tf.compat.v1.summary.image('final_distorted_image',
tf.expand_dims(distorted_image, 0))
distorted_image = tf.subtract(distorted_image, 0.5)
distorted_image = tf.multiply(distorted_image, 2.0)
return distorted_image
......@@ -265,7 +265,7 @@ def preprocess_for_eval(image,
Returns:
3-D float Tensor of prepared image.
"""
with tf.name_scope(scope, 'eval_image', [image, height, width]):
with tf.compat.v1.name_scope(scope, 'eval_image', [image, height, width]):
if image.dtype != tf.float32:
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
# Crop the central region of the image with an area containing 87.5% of
......@@ -276,8 +276,8 @@ def preprocess_for_eval(image,
if height and width:
# Resize the image to the specified height and width.
image = tf.expand_dims(image, 0)
image = tf.image.resize_bilinear(
image, [height, width], align_corners=False)
image = tf.image.resize(
image, [height, width], method=tf.image.ResizeMethod.BILINEAR)
image = tf.squeeze(image, [0])
image = tf.subtract(image, 0.5)
image = tf.multiply(image, 2.0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment