Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
31ca3b97
Commit
31ca3b97
authored
Jul 23, 2020
by
Kaushik Shivakumar
Browse files
resovle merge conflicts
parents
3e9d886d
7fcd7cba
Changes
392
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1658 additions
and
208 deletions
+1658
-208
orbit/README.md
orbit/README.md
+9
-0
orbit/__init__.py
orbit/__init__.py
+21
-0
orbit/controller.py
orbit/controller.py
+410
-0
orbit/controller_test.py
orbit/controller_test.py
+550
-0
orbit/runner.py
orbit/runner.py
+9
-16
orbit/standard_runner.py
orbit/standard_runner.py
+254
-0
orbit/standard_runner_test.py
orbit/standard_runner_test.py
+96
-0
orbit/utils.py
orbit/utils.py
+109
-66
research/attention_ocr/README.md
research/attention_ocr/README.md
+67
-6
research/attention_ocr/python/common_flags.py
research/attention_ocr/python/common_flags.py
+10
-2
research/attention_ocr/python/data_provider.py
research/attention_ocr/python/data_provider.py
+8
-10
research/attention_ocr/python/datasets/fsns.py
research/attention_ocr/python/datasets/fsns.py
+13
-11
research/attention_ocr/python/datasets/fsns_test.py
research/attention_ocr/python/datasets/fsns_test.py
+1
-1
research/attention_ocr/python/datasets/testdata/fsns/download_data.py
...ention_ocr/python/datasets/testdata/fsns/download_data.py
+3
-2
research/attention_ocr/python/demo_inference.py
research/attention_ocr/python/demo_inference.py
+12
-11
research/attention_ocr/python/demo_inference_test.py
research/attention_ocr/python/demo_inference_test.py
+40
-39
research/attention_ocr/python/eval.py
research/attention_ocr/python/eval.py
+3
-3
research/attention_ocr/python/inception_preprocessing.py
research/attention_ocr/python/inception_preprocessing.py
+16
-16
research/attention_ocr/python/metrics.py
research/attention_ocr/python/metrics.py
+20
-18
research/attention_ocr/python/metrics_test.py
research/attention_ocr/python/metrics_test.py
+7
-7
No files found.
orbit/README.md
0 → 100644
View file @
31ca3b97

# Orbit
Orbit is a customized training loop library built on top of Tensorflow 2. It
provides a flexible lightweight library that users can easily use or fork when
writing
[
customized training loop code
](
https://www.tensorflow.org/tutorials/distribute/custom_training
)
in TF2. It intergates with
`tf.distribute`
seamlessly and supports running on
different device types (CPU, GPU, and TPU).
orbit/__init__.py
0 → 100644
View file @
31ca3b97
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Orbit package definition."""
from
orbit
import
utils
from
orbit.controller
import
Controller
from
orbit.runner
import
*
from
orbit.standard_runner
import
*
o
fficial/staging/training
/controller.py
→
o
rbit
/controller.py
View file @
31ca3b97
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -14,51 +15,54 @@
# ==============================================================================
"""A light weight utilities to train TF2 models."""
from
__future__
import
absolute_import
from
__future__
import
division
# from __future__ import google_type_annotations
from
__future__
import
print_function
import
time
from
typing
import
Callable
,
Optional
,
Text
,
Union
from
absl
import
logging
from
orbit
import
runner
from
orbit
import
utils
import
tensorflow
as
tf
def
_log_info
(
message
:
Text
):
"""Logs `message` to the `info` log, and also prints to stdout."""
logging
.
info
(
message
)
print
(
message
)
import
tensorflow.compat.v2
as
tf
from
typing
import
Callable
,
Dict
,
Optional
,
Text
from
official.staging.training
import
utils
def
_validate_interval
(
interval
:
Optional
[
int
],
steps_per_loop
:
Optional
[
int
],
interval_name
:
str
):
if
interval
and
steps_per_loop
and
(
interval
%
steps_per_loop
!=
0
):
raise
ValueError
(
"The {} interval ({}) must be a multiple "
"of the steps_per_loop ({})"
.
format
(
interval_name
,
interval
,
steps_per_loop
))
class
Controller
(
object
)
:
class
Controller
:
"""Class that facilitates training and evaluation of models."""
def
__init__
(
self
,
strategy
:
Optional
[
tf
.
distribute
.
Strategy
]
=
None
,
train_fn
:
Optional
[
Callable
[[
tf
.
Tensor
],
Optional
[
Dict
[
Text
,
tf
.
Tensor
]]]]
=
None
,
eval_fn
:
Optional
[
Callable
[[
tf
.
Tensor
],
Optional
[
Dict
[
Text
,
tf
.
Tensor
]]]]
=
None
,
trainer
:
Optional
[
runner
.
AbstractTrainer
]
=
None
,
evaluator
:
Optional
[
runner
.
AbstractEvaluator
]
=
None
,
global_step
:
Optional
[
tf
.
Variable
]
=
None
,
# Train related
train_steps
:
Optional
[
int
]
=
None
,
steps_per_loop
:
Optional
[
int
]
=
None
,
summary_dir
:
Optional
[
Text
]
=
None
,
checkpoint_manager
:
Optional
[
tf
.
train
.
CheckpointManager
]
=
None
,
#
s
ummary related
#
S
ummary related
summary_interval
:
Optional
[
int
]
=
None
,
summary_dir
:
Optional
[
Text
]
=
None
,
# Evaluation related
eval_summary_dir
:
Optional
[
Text
]
=
None
,
eval_steps
:
Optional
[
int
]
=
None
,
eval_interval
:
Optional
[
int
]
=
None
):
eval_summary_dir
:
Optional
[
Text
]
=
None
):
"""Constructs a `Controller` instance.
Args:
strategy: An instance of `tf.distribute.Strategy`.
train
_fn: A callable defined as `def train_fn(num_steps)`, which
`num_steps` indicates the number of steps to run for each loop
.
eval
_fn: A callable defined as `def eval_fn(num_steps)`, which `num_steps`
indicates the number of steps for one
evaluation.
train
er: An instance of `orbit.AbstractTrainer`, which represents model
training details
.
eval
uator: An instance of `orbit.AbstractEvaluator`, which represents
model
evaluation
details
.
global_step: An integer `tf.Variable` indicating the global training step
number. Usually this can be obtained from `iterations` property of the
model's optimizer (e.g. `self.optimizer.iterations`), or users can
...
...
@@ -66,105 +70,163 @@ class Controller(object):
own global step variable, it is recommended to create the `tf.Variable`
inside strategy scope, and with
`aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA`.
train_steps: The total (maximum) number of training steps to perform.
steps_per_loop: The number of steps to run in each "inner loop" of
training (passed to the `num_steps` parameter of `train_fn`).
summary_dir: The directory to restore and write checkpoints and summaries.
If None, it will be set to `checkpoint_manager.directory`.
training (passed to the `num_steps` parameter of `trainer.train`).
checkpoint_manager: An instance of `tf.train.CheckpointManager`.
summary_interval: Step interval for training summaries. Note that this
argument only applies to the summaries outside the training loop. If the
value is None, then training summaries are not enabled.
argument only applies to the summaries inside `trainer.train` function.
Summaries outside like "steps_per_second" and outputs from
`trainer.train` function will always be enabled. If set, the value
should be divisible by steps_per_loop.
summary_dir: The directory to restore and write checkpoints and summaries.
If None, it will be set to `checkpoint_manager.directory`.
eval_summary_dir: The directory to write eval summaries. If None, it will
be set to `summary_dir`.
eval_steps: Number of steps to run evaluation.
eval_interval: Step interval for evaluation. If None, will skip evaluation
in the middle of training. Note that evaluation only happens outside the
training loop, which the loop iteration is specify by `steps_per_loop`
parameter.
Raises:
ValueError: If both `train_fn` and `eval_fn` are None.
ValueError: If `train_fn` is not None and `train_steps` is None.
ValueError: If `steps_per_loop` is None when `train_fn` is provided.
ValueError: If both `trainer` and `evaluator` are None.
ValueError: If `steps_per_loop` is not a positive integer.
ValueError: If `summary_interval` is not a positive integer or it cannot
be divisible by `steps_per_loop`.
"""
if
train_fn
is
None
and
eval_fn
is
None
:
raise
ValueError
(
"`train_fn` and `eval_fn` should not both be None"
)
# TODO(rxsang): Support training until exhaustion by passing
# `train_steps=-1`. Currently it cannot be supported with a host training
# loop because break statements are not supported with distributed dataset.
if
train_fn
is
not
None
:
if
train_steps
is
None
:
raise
ValueError
(
"`train_steps` is required when `train_fn` is "
"provided."
)
if
trainer
is
None
and
evaluator
is
None
:
raise
ValueError
(
"`trainer` and `evaluator` should not both be None"
)
if
trainer
is
not
None
:
if
steps_per_loop
is
None
:
raise
ValueError
(
"`steps_per_loop` is required when `train
_fn
is "
raise
ValueError
(
"`steps_per_loop` is required when `train
er`
is "
"provided."
)
if
not
isinstance
(
steps_per_loop
,
int
)
or
steps_per_loop
<
1
:
raise
ValueError
(
"`steps_per_loop` should be a positive integer"
)
if
summary_interval
is
not
None
and
summary_interval
<=
0
:
if
summary_interval
is
not
None
:
if
summary_interval
<=
0
:
raise
ValueError
(
"`summary_interval` should be larger than 0"
)
_validate_interval
(
summary_interval
,
steps_per_loop
,
interval_name
=
"summary"
)
self
.
trainer
=
trainer
self
.
evaluator
=
evaluator
self
.
strategy
=
strategy
or
tf
.
distribute
.
get_strategy
()
self
.
train_fn
=
train_fn
self
.
eval_fn
=
eval_fn
self
.
global_step
=
global_step
self
.
checkpoint_manager
=
checkpoint_manager
if
self
.
train_fn
is
not
None
:
self
.
train_steps
=
train_steps
self
.
steps_per_loop
=
steps_per_loop
if
summary_dir
:
self
.
summary_dir
=
summary_dir
elif
checkpoint_manager
:
self
.
summary_dir
=
checkpoint_manager
.
directory
else
:
self
.
summary_dir
=
None
if
summary_dir
is
None
and
checkpoint_manager
:
summary_dir
=
checkpoint_manager
.
directory
if
self
.
trainer
is
not
None
:
self
.
step_timer
=
None
self
.
steps_per_loop
=
steps_per_loop
self
.
summary_interval
=
summary_interval
if
self
.
summary_dir
and
self
.
summary_interval
:
summary_writer
=
tf
.
summary
.
create_file_writer
(
self
.
summary_dir
)
else
:
summary_writer
=
None
# TODO(rxsang): Consider pass SummaryManager directly into Controller for
# maximum customizability.
self
.
summary_manager
=
utils
.
SummaryManager
(
summary_writer
,
tf
.
summary
.
scalar
,
global_step
=
self
.
global_step
,
summary_interval
=
self
.
summary_interval
)
if
self
.
eval_fn
is
not
None
:
eval_summary_dir
=
eval_summary_dir
or
self
.
summary_dir
eval_summary_writer
=
tf
.
summary
.
create_file_writer
(
eval_summary_dir
)
if
eval_summary_dir
else
None
summary_dir
,
tf
.
summary
.
scalar
,
global_step
=
self
.
global_step
)
eval_summary_writer
=
None
if
self
.
evaluator
is
not
None
:
eval_summary_dir
=
eval_summary_dir
or
summary_dir
if
eval_summary_dir
==
summary_dir
and
self
.
trainer
is
not
None
:
# Reuse the summary writer if train and evaluation summary directory
# are the same.
self
.
eval_summary_manager
=
self
.
summary_manager
else
:
self
.
eval_summary_manager
=
utils
.
SummaryManager
(
eval_summary_writer
,
tf
.
summary
.
scalar
,
global_step
=
self
.
global_step
)
self
.
eval_steps
=
eval_steps
self
.
eval_interval
=
eval_interval
# Creates and initializes the interval triggers.
self
.
eval_trigger
=
utils
.
IntervalTrigger
(
self
.
eval_interval
,
self
.
global_step
.
numpy
())
# pytype: disable=attribute-error
eval_summary_dir
,
tf
.
summary
.
scalar
,
global_step
=
self
.
global_step
)
if
self
.
global_step
:
if
self
.
global_step
is
not
None
:
tf
.
summary
.
experimental
.
set_step
(
self
.
global_step
)
# Restores the model if needed.
# TODO(momernick): We probably only want to do this on certain occasions?
if
self
.
checkpoint_manager
is
not
None
:
model_restored
=
self
.
_restore_model
()
if
not
model_restored
and
self
.
checkpoint_manager
.
checkpoint_interval
:
# If the model is not restored from a checkpoint, save an initial
checkpoint_interval
=
self
.
checkpoint_manager
.
checkpoint_interval
_validate_interval
(
checkpoint_interval
,
steps_per_loop
,
interval_name
=
"checkpoint"
)
model_restored
=
self
.
restore_checkpoint
()
if
not
model_restored
and
(
checkpoint_interval
and
self
.
trainer
is
not
None
):
# If the model is not restored from a checkpoint, and
# `checkpoint_interval` is enabled for training, save an initial
# checkpoint.
ckpt_path
=
self
.
checkpoint_manager
.
save
(
checkpoint_number
=
self
.
global_step
)
logging
.
info
(
"Saved checkpoins in %s"
,
ckpt_path
)
self
.
save_checkpoint
()
def
train
(
self
,
steps
:
int
,
checkpoint_at_completion
:
bool
=
True
):
"""Runs training.
This method calls the `train` method on the Trainable object until the
global step count is equal to `steps`. It will optionally save checkpoints,
if a CheckpointManager was passed to the Controller instance's `__init__`.
Args:
steps: The global step count to train up to.
checkpoint_at_completion: Whether to save a checkpoint when this method
returns. Defaults to True (write the checkpoint). This is always
triggered, regardless of the checkpointing interval.
"""
if
self
.
trainer
is
None
:
raise
ValueError
(
"`self.trainer` is required when calling `train` "
"method."
)
if
self
.
global_step
is
None
:
raise
ValueError
(
"`self.global_step` is required when calling `train` "
"method."
)
# TODO(momernick): Support steps=None or -1 (training to exhaustion).
current_step
=
self
.
global_step
.
numpy
()
# This is an expensive access.
while
current_step
<
steps
:
logging
.
info
(
"Train at step %s of %s"
,
current_step
,
steps
)
# Calculates steps to run for the next train loop.
num_steps
=
min
(
steps
-
current_step
,
self
.
steps_per_loop
)
self
.
_train_n_steps
(
num_steps
)
self
.
_maybe_save_checkpoint
()
current_step
=
self
.
global_step
.
numpy
()
# This is an expensive access.
if
checkpoint_at_completion
:
self
.
save_checkpoint
()
def
evaluate
(
self
,
steps
:
int
=
None
):
"""Runs evaluation.
This method calls the `evaluate` method on the Evaluator object for `steps`
steps, then writes the returned summaries (if any).
Args:
steps: The number of steps to evaluate for.
Raises:
ValueError: If no checkpoint found in `self.checkpoint_manager.directory`.
ValueError: If `evaluator` is not provided.
"""
if
self
.
evaluator
is
None
:
raise
ValueError
(
"`evaluator` must be provided to call `evaluate()` "
"method."
)
steps
=
steps
or
-
1
current_step
=
self
.
global_step
.
numpy
()
if
steps
>
0
:
logging
.
info
(
"Running %s steps of evaluation at train step: %s"
,
steps
,
current_step
)
steps
=
tf
.
convert_to_tensor
(
steps
,
dtype
=
tf
.
int32
)
else
:
logging
.
info
(
"Evaluating at train step: %s"
,
current_step
)
with
self
.
eval_summary_manager
.
summary_writer
.
as_default
():
eval_outputs
=
self
.
evaluator
.
evaluate
(
steps
)
if
eval_outputs
:
eval_outputs
=
tf
.
nest
.
map_structure
(
utils
.
get_value
,
eval_outputs
)
info
=
"step: {} evaluation metric: {}"
.
format
(
current_step
,
eval_outputs
)
_log_info
(
info
)
self
.
eval_summary_manager
.
write_summaries
(
eval_outputs
)
self
.
eval_summary_manager
.
flush
()
def
_
restore_
model
(
self
,
checkpoint_path
=
None
):
def
restore_
checkpoint
(
self
,
checkpoint_path
:
Text
=
None
):
"""Restore or initialize the model.
Args:
...
...
@@ -172,153 +234,164 @@ class Controller(object):
restore. If None, will restore from `self.checkpoint_manager`.
Returns:
True if the latest checkpoint is found or restored. Otherwise False.
The path to the restored checkpoint if a restore happened, or None
if no restore occurred.
"""
with
self
.
strategy
.
scope
():
# Checkpoint restoring should be inside scope. b/139450638
if
checkpoint_path
is
not
None
:
self
.
checkpoint_manager
.
checkpoint
.
restore
(
checkpoint_path
)
return
True
return
checkpoint_path
return
self
.
checkpoint_manager
.
restore_or_initialize
()
def
_evaluate_once
(
self
,
current_step
):
"""Runs the evaluation once."""
logging
.
info
(
"Start evaluation at step: %s"
,
current_step
)
def
save_checkpoint
(
self
):
"""Checkpoint the model.
with
self
.
eval_summary_manager
.
summary_writer
.
as_default
():
eval_outputs
=
self
.
eval_fn
(
self
.
eval_steps
)
This method will write a checkpoint containing the current state of the
model.
if
eval_outputs
:
eval_outputs
=
tf
.
nest
.
map_structure
(
lambda
x
:
x
.
numpy
(),
eval_outputs
)
Raises:
ValueError: if no CheckpointManager was provided to this Controller's
init args.
"""
self
.
_maybe_save_checkpoint
(
force_trigger
=
True
)
info
=
"step: {} evaluation metric: {}"
.
format
(
current_step
,
eval_outputs
)
self
.
_log_info
(
info
)
def
train_and_evaluate
(
self
,
train_steps
:
int
=
None
,
eval_steps
:
int
=
None
,
eval_interval
:
int
=
None
):
"""Train and evaluate in an interleaved manner.
self
.
eval_summary_manager
.
write_summaries
(
eval_outputs
)
self
.
eval_summary_manager
.
flush
()
This method will train the model until the global step count equals
`train_steps`, running an evaluation for `eval_steps` every `eval_interval`
training steps. In addition, this method will run a final evaluation at the
end of the training sequence.
def
_maybe_save_checkpoints
(
self
,
current_step
,
force_trigger
=
False
):
if
self
.
checkpoint_manager
and
self
.
checkpoint_manager
.
checkpoint_interval
:
ckpt_path
=
self
.
checkpoint_manager
.
save
(
checkpoint_number
=
current_step
,
check_interval
=
not
force_trigger
)
if
ckpt_path
is
not
None
:
logging
.
info
(
"Saved checkpoins in %s"
,
ckpt_path
)
Args:
train_steps: The global step count to train up to.
eval_steps: The number of steps to run during an evaluation. If None,
this method will evaluate over the entire evaluation dataset.
eval_interval: The number of training steps to run between evalutions.
Must be a multiple of the controller's `steps_per_loop` init arg. If
None, evaluation will only be performed after training is complete.
def
_maybe_evaluate
(
self
,
current_step
,
force_trigger
=
False
):
if
self
.
eval_trigger
(
current_step
,
force_trigger
):
self
.
_evaluate_once
(
current_step
)
Raises:
ValueError: If eval_interval is not a multiple of self.steps_per_loop.
"""
_validate_interval
(
eval_interval
,
self
.
steps_per_loop
,
interval_name
=
"eval"
)
current_step
=
self
.
global_step
.
numpy
()
# This is an expensive access.
eval_interval
=
eval_interval
or
(
train_steps
-
current_step
)
while
current_step
<
train_steps
:
interval
=
min
(
train_steps
-
current_step
,
eval_interval
)
num_steps
=
current_step
+
interval
self
.
train
(
steps
=
num_steps
,
checkpoint_at_completion
=
False
)
self
.
evaluate
(
steps
=
eval_steps
)
current_step
=
self
.
global_step
.
numpy
()
# This is an expensive access.
self
.
save_checkpoint
()
def
evaluate_continuously
(
self
,
steps
:
int
=
None
,
timeout
:
Optional
[
Union
[
int
,
float
]]
=
None
,
timeout_fn
:
Optional
[
Callable
[[],
bool
]]
=
None
):
"""Monitor a directory and evaluate on checkpoints in it.
This method continuously monitors a directory as specified by this
Controller's CheckpointManager init arg and runs evaluation on the
checkpoints found there.
def
_log_info
(
self
,
message
):
"""Logs `message` to the `info` log, and also prints to stdout."""
logging
.
info
(
message
)
print
(
message
)
Args:
steps: The number of steps to run when evaluating.
timeout: The maximum number of seconds to wait between checkpoints. See
tf.train.checkpoints_iterator documentation.
timeout_fn: Optional callable to call after a timeout. If the function
returns True, then it means that no new checkpoints will be generated
and the iterator will exit.
Raises:
ValueError: If no checkpoint found in `self.checkpoint_manager.directory`.
ValueError: If `evaluator` was not provided as a controller init arg.
def
train
(
self
,
evaluate
=
True
):
"""Runs the training, with optional evaluation.
"""
for
checkpoint_path
in
tf
.
train
.
checkpoints_iterator
(
self
.
checkpoint_manager
.
directory
,
timeout
=
timeout
,
timeout_fn
=
timeout_fn
):
self
.
restore_checkpoint
(
checkpoint_path
)
self
.
evaluate
(
steps
)
This handles evaluation, gathering summaries, and saving checkpoints.
def
_train_n_steps
(
self
,
num_steps
:
int
):
"""Run training for `num_steps`.
It will also write training outputs to summaries if there is any.
Args:
evaluate: A boolean indicates whether to perform evaluation dur
ing
training
.
num_steps: An integer indicates how many steps to run for this train
ing
loop
.
Raises:
RuntimeError: If `global_step` is not updated correctly in `train_fn`.
RuntimeError: If `global_step` is not updated correctly in
`trainer.train`.
"""
if
self
.
train_fn
is
None
:
raise
ValueError
(
"`self.train_fn` is required when calling `train` "
"method."
)
if
self
.
global_step
is
None
:
raise
ValueError
(
"`self.global_step` is required when calling `train` "
"method."
)
if
evaluate
and
self
.
eval_fn
is
None
:
raise
ValueError
(
"`self.eval_fn` is required when calling `train` method "
"with `evaluate=True`"
)
if
not
self
.
step_timer
:
self
.
step_timer
=
StepTimer
(
self
.
global_step
)
step_timer
=
_StepTimer
(
self
.
global_step
)
current_step
=
self
.
global_step
.
numpy
()
logging
.
info
(
"Train at step %s of %s"
,
current_step
,
self
.
train_steps
)
while
current_step
<
self
.
train_steps
:
# Calculates steps to run for the next train loop.
steps_per_loop
=
min
(
self
.
train_steps
-
current_step
,
self
.
steps_per_loop
)
logging
.
info
(
"Entering training loop
with %s steps, at step %s of %
s"
,
steps_per_loop
,
current_step
,
self
.
train
_steps
)
current_step
+=
steps
_per_loop
steps
_per_loop
=
tf
.
convert_to_tensor
(
steps
_per_loop
,
dtype
=
tf
.
int32
)
current_step
=
self
.
global_step
.
numpy
(
)
logging
.
info
(
"Entering training loop
at step %s to run %s step
s"
,
current_step
,
num
_steps
)
current_step
+=
num_
steps
num_
steps
=
tf
.
convert_to_tensor
(
num_
steps
,
dtype
=
tf
.
int32
)
with
self
.
summary_manager
.
summary_writer
.
as_default
():
train_outputs
=
self
.
train_fn
(
steps_per_loop
)
# Create a lambda that returns true when summaries should be written.
should_record
=
False
# Allows static optimization in no-summary cases.
if
self
.
summary_interval
:
should_record
=
lambda
:
(
self
.
global_step
%
self
.
summary_interval
==
0
)
with
tf
.
summary
.
record_if
(
should_record
):
train_outputs
=
self
.
trainer
.
train
(
num_steps
)
# Updates and verifies the current step after a training loop finishes.
if
current_step
!=
self
.
global_step
.
numpy
():
raise
RuntimeError
(
"`
self.train_fn`
is not updating
`global_step`
"
"
correctly, expected: %s, actual: %s"
%
raise
RuntimeError
(
"`
trainer.train` function
is not updating "
"`global_step`
correctly, expected: %s, actual: %s"
%
(
current_step
,
self
.
global_step
.
numpy
()))
# Print information like metrics and steps_per_second after a training
# loop.
if
train_outputs
:
train_outputs
=
tf
.
nest
.
map_structure
(
lambda
x
:
x
.
numpy
(),
train_outputs
)
steps_per_second
=
step_timer
.
steps_per_second
()
train_outputs
=
tf
.
nest
.
map_structure
(
utils
.
get_value
,
train_outputs
)
train_outputs
=
train_outputs
or
{}
steps_per_second
=
self
.
step_timer
.
steps_per_second
()
info
=
"step: {} steps_per_second: {:.2f} {}"
.
format
(
current_step
,
steps_per_second
,
train_outputs
)
self
.
_log_info
(
info
)
_log_info
(
info
)
train_outputs
=
train_outputs
or
{}
train_outputs
[
"steps_per_second"
]
=
steps_per_second
self
.
summary_manager
.
write_summaries
(
train_outputs
)
self
.
_maybe_save_checkpoints
(
current_step
)
if
evaluate
:
self
.
_maybe_evaluate
(
current_step
)
self
.
summary_manager
.
write_summaries
(
train_outputs
,
always_write
=
True
)
self
.
summary_manager
.
flush
()
self
.
_maybe_save_checkpoints
(
current_step
,
force_trigger
=
True
)
if
evaluate
:
self
.
_maybe_evaluate
(
current_step
,
force_trigger
=
True
)
def
evaluate
(
self
,
continuous
=
False
,
timeout_fn
=
None
):
"""Runs the evaluation.
def
_maybe_save_checkpoint
(
self
,
force_trigger
:
bool
=
False
):
"""Save checkpoints if necessary.
Args:
continuous: If `True`, will continously monitor the checkpoint directory
to evaluate on the latest checkpoint. If `False`, will do the evaluation
once.
timeout_fn: Optional callable to call after a timeout. If the function
returns True, then it means that no new checkpoints will be generated
and the iterator will exit.
force_trigger: A boolean indicates whether to force saving checkpoints
regardless of the checkpoint interval.
R
aise
s:
ValueError: If no checkpoint found in `self.checkpoint_manager.directory`
.
R
eturn
s:
A boolean indicating whether a checkpoint was saved
.
"""
if
self
.
eval_fn
is
None
:
raise
ValueError
(
"`self.eval_fn` should not be None to call "
"`evaluate()` method."
)
if
not
continuous
and
timeout_fn
is
not
None
:
raise
ValueError
(
"`timeout_fn` can be only passed when `continuous` is "
"True"
)
if
continuous
:
for
checkpoint_path
in
tf
.
train
.
checkpoints_iterator
(
self
.
checkpoint_manager
.
directory
,
timeout_fn
=
timeout_fn
):
self
.
_restore_model
(
checkpoint_path
)
self
.
_evaluate_once
(
self
.
global_step
.
numpy
())
return
latest_checkpoint
=
self
.
checkpoint_manager
.
latest_checkpoint
if
not
latest_checkpoint
:
raise
ValueError
(
"no checkpoint found in dir %s"
%
self
.
checkpoint_manager
.
directory
)
self
.
_restore_model
()
self
.
_evaluate_once
(
self
.
global_step
.
numpy
())
if
self
.
checkpoint_manager
and
self
.
checkpoint_manager
.
checkpoint_interval
:
ckpt_path
=
self
.
checkpoint_manager
.
save
(
checkpoint_number
=
self
.
global_step
.
numpy
(),
check_interval
=
not
force_trigger
)
if
ckpt_path
is
not
None
:
logging
.
info
(
"Saved checkpoints in %s"
,
ckpt_path
)
return
True
return
False
class
_
StepTimer
(
object
)
:
class
StepTimer
:
"""Utility class for measuring steps/second."""
def
__init__
(
self
,
step
):
...
...
o
fficial/staging/training
/controller_test.py
→
o
rbit
/controller_test.py
View file @
31ca3b97
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -12,35 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for official.staging.training.controller."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
"""Tests for orbit.controller."""
import
os
from
absl
import
logging
from
absl.testing
import
parameterized
import
numpy
as
np
import
tensorflow
as
tf
from
orbit
import
controller
from
orbit
import
standard_runner
from
tensorflow.python.distribute
import
combinations
from
tensorflow.python.distribute
import
strategy_combinations
from
official.staging.training
import
controller
from
official.staging.training
import
standard_runnable
def
all_strategy_combinations
():
"""Gets combinations of distribution strategies."""
return
combinations
.
combine
(
strategy
=
[
strategy_combinations
.
one_device_strategy
,
strategy_combinations
.
tpu_strategy
,
strategy_combinations
.
one_device_strategy_gpu
,
strategy_combinations
.
mirrored_strategy_with_gpu_and_cpu
,
],
mode
=
"eager"
,
)
import
tensorflow
as
tf
def
create_model
():
...
...
@@ -57,7 +39,7 @@ def summaries_with_matching_keyword(keyword, summary_dir):
if
event
.
summary
is
not
None
:
for
value
in
event
.
summary
.
value
:
if
keyword
in
value
.
tag
:
tf
.
compat
.
v1
.
logging
.
error
(
event
)
logging
.
info
(
event
)
yield
event
.
summary
...
...
@@ -69,30 +51,33 @@ def check_eventfile_for_keyword(keyword, summary_dir):
def
dataset_fn
(
ctx
):
del
ctx
inputs
=
np
.
zeros
((
10
,
3
),
dtype
=
np
.
float32
)
targets
=
np
.
zero
s
((
10
,
4
),
dtype
=
np
.
float32
)
targets
=
np
.
one
s
((
10
,
4
),
dtype
=
np
.
float32
)
dataset
=
tf
.
data
.
Dataset
.
from_tensor_slices
((
inputs
,
targets
))
dataset
=
dataset
.
repeat
(
100
)
dataset
=
dataset
.
batch
(
10
,
drop_remainder
=
True
)
return
dataset
class
TestRunn
abl
e
(
standard_runn
abl
e
.
StandardTrain
abl
e
,
standard_runn
abl
e
.
StandardEvalua
ble
):
class
TestRunne
r
(
standard_runne
r
.
StandardTraine
r
,
standard_runne
r
.
StandardEvalua
tor
):
"""Implements the training and evaluation APIs for the test model."""
def
__init__
(
self
):
standard_runnable
.
StandardTrainable
.
__init__
(
self
)
standard_runnable
.
StandardEvaluable
.
__init__
(
self
)
def
__init__
(
self
,
return_numpy
=
False
):
self
.
strategy
=
tf
.
distribute
.
get_strategy
()
self
.
model
=
create_model
()
self
.
optimizer
=
tf
.
keras
.
optimizers
.
RMSprop
()
self
.
optimizer
=
tf
.
keras
.
optimizers
.
RMSprop
(
learning_rate
=
0.1
)
self
.
global_step
=
self
.
optimizer
.
iterations
self
.
train_loss
=
tf
.
keras
.
metrics
.
Mean
(
"train_loss"
,
dtype
=
tf
.
float32
)
self
.
eval_loss
=
tf
.
keras
.
metrics
.
Mean
(
"eval_loss"
,
dtype
=
tf
.
float32
)
def
build_train_dataset
(
self
):
return
self
.
strategy
.
experimental_distribute_datasets_from_function
(
dataset_fn
)
self
.
return_numpy
=
return_numpy
train_dataset
=
(
self
.
strategy
.
experimental_distribute_datasets_from_function
(
dataset_fn
)
)
eval_dataset
=
(
self
.
strategy
.
experimental_distribute_datasets_from_function
(
dataset_fn
)
)
standard_runner
.
StandardTrainer
.
__init__
(
self
,
train_dataset
)
standard_runner
.
StandardEvaluator
.
__init__
(
self
,
eval_dataset
)
def
train_step
(
self
,
iterator
):
...
...
@@ -101,7 +86,7 @@ class TestRunnable(standard_runnable.StandardTrainable,
inputs
,
targets
=
inputs
with
tf
.
GradientTape
()
as
tape
:
outputs
=
self
.
model
(
inputs
)
loss
=
tf
.
math
.
reduce_
sum
(
outputs
-
targets
)
loss
=
tf
.
reduce_
mean
(
tf
.
keras
.
losses
.
MSE
(
targets
,
outputs
)
)
grads
=
tape
.
gradient
(
loss
,
self
.
model
.
variables
)
self
.
optimizer
.
apply_gradients
(
zip
(
grads
,
self
.
model
.
variables
))
self
.
train_loss
.
update_state
(
loss
)
...
...
@@ -109,8 +94,9 @@ class TestRunnable(standard_runnable.StandardTrainable,
self
.
strategy
.
run
(
_replicated_step
,
args
=
(
next
(
iterator
),))
def
train_loop_end
(
self
):
train_loss
=
self
.
train_loss
.
result
()
return
{
"loss"
:
self
.
train_loss
.
result
()
,
"loss"
:
train_loss
.
numpy
()
if
self
.
return_numpy
else
train_loss
,
}
def
build_eval_dataset
(
self
):
...
...
@@ -126,39 +112,110 @@ class TestRunnable(standard_runnable.StandardTrainable,
"""Replicated evaluation step."""
inputs
,
targets
=
inputs
outputs
=
self
.
model
(
inputs
)
loss
=
tf
.
math
.
reduce_
sum
(
outputs
-
targets
)
loss
=
tf
.
reduce_
mean
(
tf
.
keras
.
losses
.
MSE
(
targets
,
outputs
)
)
self
.
eval_loss
.
update_state
(
loss
)
self
.
strategy
.
run
(
_replicated_step
,
args
=
(
next
(
iterator
),))
def
eval_end
(
self
):
eval_loss
=
self
.
eval_loss
.
result
()
return
{
"eval_loss"
:
eval_loss
.
numpy
()
if
self
.
return_numpy
else
eval_loss
,
}
class
TestEvaluator
(
standard_runner
.
StandardEvaluator
):
"""Implements the training and evaluation APIs for the test model."""
def
__init__
(
self
):
self
.
strategy
=
tf
.
distribute
.
get_strategy
()
self
.
model
=
create_model
()
eval_dataset
=
self
.
strategy
.
experimental_distribute_datasets_from_function
(
dataset_fn
)
standard_runner
.
StandardEvaluator
.
__init__
(
self
,
eval_dataset
)
def
eval_reduce
(
self
,
state
,
output
):
state
.
append
(
output
)
return
state
def
eval_begin
(
self
):
return
[]
def
eval_step
(
self
,
iterator
):
def
_replicated_step
(
inputs
):
"""Replicated evaluation step."""
inputs
,
targets
=
inputs
outputs
=
self
.
model
(
inputs
)
loss
=
tf
.
reduce_mean
(
tf
.
keras
.
losses
.
MSE
(
targets
,
outputs
))
return
loss
per_replica_losses
=
self
.
strategy
.
run
(
_replicated_step
,
args
=
(
next
(
iterator
),))
mean_loss
=
self
.
strategy
.
reduce
(
tf
.
distribute
.
ReduceOp
.
MEAN
,
per_replica_losses
,
axis
=
None
)
return
mean_loss
def
eval_end
(
self
,
outputs
):
return
{
"eval_loss"
:
self
.
eval_loss
.
result
(
),
"eval_loss"
:
tf
.
reduce_mean
(
outputs
),
}
class
TestTrainerWithSummaries
(
standard_runner
.
StandardTrainer
):
"""A Trainer model with summaries for testing purposes."""
def
__init__
(
self
):
self
.
strategy
=
tf
.
distribute
.
get_strategy
()
self
.
model
=
create_model
()
self
.
optimizer
=
tf
.
keras
.
optimizers
.
RMSprop
(
learning_rate
=
0.1
)
self
.
global_step
=
self
.
optimizer
.
iterations
self
.
train_loss
=
tf
.
keras
.
metrics
.
Mean
(
"train_loss"
,
dtype
=
tf
.
float32
)
train_dataset
=
(
self
.
strategy
.
experimental_distribute_datasets_from_function
(
dataset_fn
)
)
standard_runner
.
StandardTrainer
.
__init__
(
self
,
train_dataset
,
use_tpu_summary_optimization
=
True
)
def
build_train_dataset
(
self
):
return
self
.
strategy
.
experimental_distribute_datasets_from_function
(
dataset_fn
)
def
train_step
(
self
,
iterator
):
def
_replicated_step
(
inputs
):
"""Replicated training step."""
inputs
,
targets
=
inputs
with
tf
.
GradientTape
()
as
tape
:
outputs
=
self
.
model
(
inputs
)
loss
=
tf
.
reduce_mean
(
tf
.
keras
.
losses
.
MSE
(
targets
,
outputs
))
tf
.
summary
.
scalar
(
"loss"
,
loss
)
grads
=
tape
.
gradient
(
loss
,
self
.
model
.
variables
)
self
.
optimizer
.
apply_gradients
(
zip
(
grads
,
self
.
model
.
variables
))
self
.
train_loss
.
update_state
(
loss
)
self
.
strategy
.
run
(
_replicated_step
,
args
=
(
next
(
iterator
),))
class
ControllerTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
def
setUp
(
self
):
super
(
ControllerTest
,
self
).
setUp
()
super
().
setUp
()
self
.
model_dir
=
self
.
get_temp_dir
()
def
test_no_checkpoint
(
self
):
test_runn
abl
e
=
TestRunn
abl
e
()
test_runne
r
=
TestRunne
r
()
# No checkpoint manager and no strategy.
test_controller
=
controller
.
Controller
(
train_fn
=
test_runnable
.
train
,
eval_fn
=
test_runnable
.
evaluate
,
global_step
=
test_runnable
.
global_step
,
train_steps
=
10
,
trainer
=
test_runner
,
evaluator
=
test_runner
,
global_step
=
test_runner
.
global_step
,
steps_per_loop
=
2
,
summary_dir
=
os
.
path
.
join
(
self
.
model_dir
,
"summaries/train"
),
summary_interval
=
2
,
eval_summary_dir
=
os
.
path
.
join
(
self
.
model_dir
,
"summaries/eval"
),
eval_steps
=
2
,
eval_interval
=
5
)
test_controller
.
train
(
evaluate
=
True
)
self
.
assertEqual
(
test_runnable
.
global_step
.
numpy
(),
10
)
eval_summary_dir
=
os
.
path
.
join
(
self
.
model_dir
,
"summaries/eval"
))
test_controller
.
train_and_evaluate
(
train_steps
=
10
,
eval_steps
=
2
,
eval_interval
=
6
)
self
.
assertEqual
(
test_runner
.
global_step
,
10
)
# Loss and accuracy values should be written into summaries.
self
.
assertNotEmpty
(
tf
.
io
.
gfile
.
listdir
(
os
.
path
.
join
(
self
.
model_dir
,
"summaries/train"
)))
...
...
@@ -171,51 +228,46 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
check_eventfile_for_keyword
(
"eval_loss"
,
os
.
path
.
join
(
self
.
model_dir
,
"summaries/eval"
)))
# No checkpoint, so global step starts from 0.
test_runnable
.
global_step
.
assign
(
0
)
test_controller
.
train
(
evaluate
=
True
)
self
.
assertEqual
(
test_runnable
.
global_step
.
numpy
(),
10
)
test_runner
.
global_step
.
assign
(
0
)
test_controller
.
train_and_evaluate
(
train_steps
=
10
,
eval_steps
=
2
,
eval_interval
=
6
)
self
.
assertEqual
(
test_runner
.
global_step
,
10
)
def
test_no_checkpoint_and_summaries
(
self
):
test_runn
abl
e
=
TestRunn
abl
e
()
test_runne
r
=
TestRunne
r
()
# No checkpoint + summary directories.
test_controller
=
controller
.
Controller
(
train_fn
=
test_runnable
.
train
,
eval_fn
=
test_runnable
.
evaluate
,
global_step
=
test_runnable
.
global_step
,
train_steps
=
10
,
steps_per_loop
=
2
,
eval_steps
=
2
,
eval_interval
=
5
)
test_controller
.
train
(
evaluate
=
True
)
self
.
assertEqual
(
test_runnable
.
global_step
.
numpy
(),
10
)
@
combinations
.
generate
(
all_strategy_combinations
())
def
test_train_and_evaluate
(
self
,
strategy
):
with
strategy
.
scope
():
test_runnable
=
TestRunnable
()
trainer
=
test_runner
,
evaluator
=
test_runner
,
global_step
=
test_runner
.
global_step
,
steps_per_loop
=
2
)
test_controller
.
train_and_evaluate
(
train_steps
=
10
,
eval_steps
=
2
,
eval_interval
=
6
)
self
.
assertEqual
(
test_runner
.
global_step
,
10
)
@
parameterized
.
named_parameters
((
"return_numpy"
,
True
),
(
"return_tensor"
,
False
))
def
test_train_and_evaluate
(
self
,
return_numpy
):
test_runner
=
TestRunner
(
return_numpy
=
return_numpy
)
checkpoint
=
tf
.
train
.
Checkpoint
(
model
=
test_runn
abl
e
.
model
,
optimizer
=
test_runn
abl
e
.
optimizer
)
model
=
test_runne
r
.
model
,
optimizer
=
test_runne
r
.
optimizer
)
checkpoint_manager
=
tf
.
train
.
CheckpointManager
(
checkpoint
,
self
.
model_dir
,
max_to_keep
=
None
,
step_counter
=
test_runn
abl
e
.
global_step
,
step_counter
=
test_runne
r
.
global_step
,
checkpoint_interval
=
10
)
test_controller
=
controller
.
Controller
(
strategy
=
strategy
,
train_fn
=
test_runnable
.
train
,
eval_fn
=
test_runnable
.
evaluate
,
global_step
=
test_runnable
.
global_step
,
train_steps
=
10
,
trainer
=
test_runner
,
evaluator
=
test_runner
,
global_step
=
test_runner
.
global_step
,
steps_per_loop
=
2
,
summary_dir
=
os
.
path
.
join
(
self
.
model_dir
,
"summaries/train"
),
summary_interval
=
2
,
checkpoint_manager
=
checkpoint_manager
,
eval_summary_dir
=
os
.
path
.
join
(
self
.
model_dir
,
"summaries/eval"
),
eval_steps
=
2
,
eval_interval
=
5
)
test_controller
.
train
(
evaluate
=
True
)
eval_summary_dir
=
os
.
path
.
join
(
self
.
model_dir
,
"summaries/eval"
))
test_controller
.
train_and_evaluate
(
train_steps
=
10
,
eval_steps
=
2
,
eval_interval
=
6
)
# Checkpoints are saved.
self
.
assertNotEmpty
(
tf
.
io
.
gfile
.
glob
(
os
.
path
.
join
(
self
.
model_dir
,
"ckpt*"
)))
...
...
@@ -232,31 +284,26 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
check_eventfile_for_keyword
(
"eval_loss"
,
os
.
path
.
join
(
self
.
model_dir
,
"summaries/eval"
)))
@
combinations
.
generate
(
all_strategy_combinations
())
def
test_train_only
(
self
,
strategy
):
with
strategy
.
scope
():
test_runnable
=
TestRunnable
()
def
test_train_only
(
self
):
test_runner
=
TestRunner
()
checkpoint
=
tf
.
train
.
Checkpoint
(
model
=
test_runn
abl
e
.
model
,
optimizer
=
test_runn
abl
e
.
optimizer
)
model
=
test_runne
r
.
model
,
optimizer
=
test_runne
r
.
optimizer
)
checkpoint_manager
=
tf
.
train
.
CheckpointManager
(
checkpoint
,
self
.
model_dir
,
max_to_keep
=
None
,
step_counter
=
test_runn
abl
e
.
global_step
,
step_counter
=
test_runne
r
.
global_step
,
checkpoint_interval
=
10
)
test_controller
=
controller
.
Controller
(
strategy
=
strategy
,
train_fn
=
test_runnable
.
train
,
global_step
=
test_runnable
.
global_step
,
train_steps
=
10
,
trainer
=
test_runner
,
global_step
=
test_runner
.
global_step
,
steps_per_loop
=
2
,
summary_dir
=
os
.
path
.
join
(
self
.
model_dir
,
"summaries/train"
),
summary_interval
=
2
,
checkpoint_manager
=
checkpoint_manager
,
eval_summary_dir
=
os
.
path
.
join
(
self
.
model_dir
,
"summaries/eval"
),
)
test_controller
.
train
(
evaluate
=
False
)
test_controller
.
train
(
steps
=
10
)
# Checkpoints are saved.
self
.
assertNotEmpty
(
tf
.
io
.
gfile
.
glob
(
os
.
path
.
join
(
self
.
model_dir
,
"ckpt*"
)))
...
...
@@ -270,29 +317,23 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
self
.
assertFalse
(
tf
.
io
.
gfile
.
exists
(
os
.
path
.
join
(
self
.
model_dir
,
"summaries/eval"
)))
@
combinations
.
generate
(
all_strategy_combinations
())
def
test_evaluate_only
(
self
,
strategy
):
with
strategy
.
scope
():
test_runnable
=
TestRunnable
()
def
test_evaluate_only
(
self
):
test_runner
=
TestRunner
()
checkpoint
=
tf
.
train
.
Checkpoint
(
model
=
test_runn
abl
e
.
model
)
checkpoint
=
tf
.
train
.
Checkpoint
(
model
=
test_runne
r
.
model
)
checkpoint
.
save
(
os
.
path
.
join
(
self
.
model_dir
,
"ckpt"
))
checkpoint_manager
=
tf
.
train
.
CheckpointManager
(
checkpoint
,
self
.
model_dir
,
max_to_keep
=
None
,
step_counter
=
test_runn
abl
e
.
global_step
)
step_counter
=
test_runne
r
.
global_step
)
test_controller
=
controller
.
Controller
(
strategy
=
strategy
,
eval_fn
=
test_runnable
.
evaluate
,
global_step
=
test_runnable
.
global_step
,
evaluator
=
test_runner
,
global_step
=
test_runner
.
global_step
,
checkpoint_manager
=
checkpoint_manager
,
summary_dir
=
os
.
path
.
join
(
self
.
model_dir
,
"summaries/train"
),
eval_summary_dir
=
os
.
path
.
join
(
self
.
model_dir
,
"summaries/eval"
),
eval_steps
=
2
,
eval_interval
=
5
)
test_controller
.
evaluate
()
eval_summary_dir
=
os
.
path
.
join
(
self
.
model_dir
,
"summaries/eval"
))
test_controller
.
evaluate
(
steps
=
2
)
# Only eval summaries are written
self
.
assertFalse
(
...
...
@@ -303,6 +344,207 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
check_eventfile_for_keyword
(
"eval_loss"
,
os
.
path
.
join
(
self
.
model_dir
,
"summaries/eval"
)))
# Tests continuous eval with timeout and timeout_fn.
done_file
=
os
.
path
.
join
(
self
.
model_dir
,
"summaries/eval/Done"
)
def
timeout_fn
():
with
tf
.
io
.
gfile
.
GFile
(
done_file
,
"w"
)
as
f
:
f
.
write
(
"DONE"
)
return
True
test_controller
=
controller
.
Controller
(
evaluator
=
test_runner
,
global_step
=
test_runner
.
global_step
,
checkpoint_manager
=
checkpoint_manager
,
eval_summary_dir
=
os
.
path
.
join
(
self
.
model_dir
,
"summaries/eval"
))
test_controller
.
evaluate_continuously
(
timeout
=
1
,
timeout_fn
=
timeout_fn
,
steps
=
2
)
self
.
assertNotEmpty
(
tf
.
io
.
gfile
.
glob
(
done_file
))
def
test_no_eval_steps
(
self
):
test_runner
=
TestRunner
()
checkpoint
=
tf
.
train
.
Checkpoint
(
model
=
test_runner
.
model
)
checkpoint
.
save
(
os
.
path
.
join
(
self
.
model_dir
,
"ckpt"
))
checkpoint_manager
=
tf
.
train
.
CheckpointManager
(
checkpoint
,
self
.
model_dir
,
max_to_keep
=
None
,
step_counter
=
test_runner
.
global_step
)
test_controller
=
controller
.
Controller
(
evaluator
=
test_runner
,
global_step
=
test_runner
.
global_step
,
checkpoint_manager
=
checkpoint_manager
)
test_controller
.
evaluate
()
def
test_already_trained_model
(
self
):
test_runner
=
TestRunner
()
test_runner
.
global_step
.
assign
(
10
)
checkpoint
=
tf
.
train
.
Checkpoint
(
model
=
test_runner
.
model
,
optimizer
=
test_runner
.
optimizer
)
checkpoint_manager
=
tf
.
train
.
CheckpointManager
(
checkpoint
,
self
.
model_dir
,
max_to_keep
=
None
,
step_counter
=
test_runner
.
global_step
,
checkpoint_interval
=
10
)
test_controller
=
controller
.
Controller
(
trainer
=
test_runner
,
global_step
=
test_runner
.
global_step
,
steps_per_loop
=
2
,
checkpoint_manager
=
checkpoint_manager
)
# `global_step` is already `train_steps`.
test_controller
.
train
(
steps
=
10
)
def
test_summaries_inside_train_fn
(
self
):
test_runner
=
TestTrainerWithSummaries
()
checkpoint
=
tf
.
train
.
Checkpoint
(
model
=
test_runner
.
model
,
optimizer
=
test_runner
.
optimizer
)
checkpoint_manager
=
tf
.
train
.
CheckpointManager
(
checkpoint
,
self
.
model_dir
,
max_to_keep
=
None
,
step_counter
=
test_runner
.
global_step
)
test_controller
=
controller
.
Controller
(
trainer
=
test_runner
,
global_step
=
test_runner
.
global_step
,
steps_per_loop
=
2
,
summary_dir
=
os
.
path
.
join
(
self
.
model_dir
,
"summaries/train"
),
summary_interval
=
2
,
checkpoint_manager
=
checkpoint_manager
,
)
test_controller
.
train
(
steps
=
10
)
# Checkpoints are saved.
self
.
assertEmpty
(
tf
.
io
.
gfile
.
glob
(
os
.
path
.
join
(
self
.
model_dir
,
"ckpt*"
)))
# Only train summaries are written.
self
.
assertNotEmpty
(
tf
.
io
.
gfile
.
listdir
(
os
.
path
.
join
(
self
.
model_dir
,
"summaries/train"
)))
self
.
assertTrue
(
check_eventfile_for_keyword
(
"loss"
,
os
.
path
.
join
(
self
.
model_dir
,
"summaries/train"
)))
self
.
assertFalse
(
tf
.
io
.
gfile
.
exists
(
os
.
path
.
join
(
self
.
model_dir
,
"summaries/eval"
)))
def
test_train_and_evaluate_with_same_summary_dir
(
self
):
test_runner
=
TestRunner
()
checkpoint
=
tf
.
train
.
Checkpoint
(
model
=
test_runner
.
model
,
optimizer
=
test_runner
.
optimizer
)
checkpoint_manager
=
tf
.
train
.
CheckpointManager
(
checkpoint
,
self
.
model_dir
,
max_to_keep
=
None
,
step_counter
=
test_runner
.
global_step
)
test_controller
=
controller
.
Controller
(
trainer
=
test_runner
,
evaluator
=
test_runner
,
global_step
=
test_runner
.
global_step
,
steps_per_loop
=
2
,
summary_dir
=
os
.
path
.
join
(
self
.
model_dir
,
"summaries"
),
checkpoint_manager
=
checkpoint_manager
,
eval_summary_dir
=
os
.
path
.
join
(
self
.
model_dir
,
"summaries"
))
test_controller
.
train_and_evaluate
(
train_steps
=
10
,
eval_steps
=
2
,
eval_interval
=
6
)
# Loss and accuracy values should be written into summaries.
self
.
assertNotEmpty
(
tf
.
io
.
gfile
.
listdir
(
os
.
path
.
join
(
self
.
model_dir
,
"summaries"
)))
self
.
assertTrue
(
check_eventfile_for_keyword
(
"loss"
,
os
.
path
.
join
(
self
.
model_dir
,
"summaries"
)))
self
.
assertTrue
(
check_eventfile_for_keyword
(
"eval_loss"
,
os
.
path
.
join
(
self
.
model_dir
,
"summaries"
)))
def
test_early_stop_on_eval_loss
(
self
):
test_runner
=
TestRunner
()
class
EarlyStopController
(
controller
.
Controller
):
"""A subclass of Controller supports early stopping."""
def
train_and_evaluate
(
self
,
train_steps
:
int
=
None
,
eval_steps
:
int
=
None
,
eval_interval
:
int
=
None
):
while
self
.
global_step
.
numpy
()
<
train_steps
:
interval
=
min
(
train_steps
-
self
.
global_step
.
numpy
(),
eval_interval
)
num_steps
=
self
.
global_step
.
numpy
()
+
interval
self
.
train
(
steps
=
num_steps
,
checkpoint_at_completion
=
False
)
self
.
evaluate
(
steps
=
eval_steps
)
# Early stop condition.
if
test_runner
.
eval_loss
.
result
()
<
0.1
:
logging
.
info
(
"Training early stopped as eval_loss %s is less than 0.1"
,
test_runner
.
eval_loss
.
result
())
return
checkpoint
=
tf
.
train
.
Checkpoint
(
model
=
test_runner
.
model
,
optimizer
=
test_runner
.
optimizer
)
checkpoint_manager
=
tf
.
train
.
CheckpointManager
(
checkpoint
,
self
.
model_dir
,
max_to_keep
=
None
,
step_counter
=
test_runner
.
global_step
,
checkpoint_interval
=
10
)
test_controller
=
EarlyStopController
(
trainer
=
test_runner
,
evaluator
=
test_runner
,
global_step
=
test_runner
.
global_step
,
steps_per_loop
=
2
,
checkpoint_manager
=
checkpoint_manager
)
test_controller
.
train_and_evaluate
(
train_steps
=
10
,
eval_steps
=
6
,
eval_interval
=
2
)
self
.
assertLess
(
test_runner
.
global_step
,
10
)
def
test_evaluate_with_loss_outputs
(
self
):
test_evaluator
=
TestEvaluator
()
checkpoint
=
tf
.
train
.
Checkpoint
(
model
=
test_evaluator
.
model
)
checkpoint
.
save
(
os
.
path
.
join
(
self
.
model_dir
,
"ckpt"
))
checkpoint_manager
=
tf
.
train
.
CheckpointManager
(
checkpoint
,
self
.
model_dir
,
max_to_keep
=
None
)
test_controller
=
controller
.
Controller
(
evaluator
=
test_evaluator
,
global_step
=
tf
.
Variable
(
0
,
dtype
=
tf
.
int64
),
checkpoint_manager
=
checkpoint_manager
,
eval_summary_dir
=
os
.
path
.
join
(
self
.
model_dir
,
"summaries/eval"
))
test_controller
.
evaluate
(
steps
=
5
)
# Only eval summaries are written
self
.
assertNotEmpty
(
tf
.
io
.
gfile
.
listdir
(
os
.
path
.
join
(
self
.
model_dir
,
"summaries/eval"
)))
self
.
assertTrue
(
check_eventfile_for_keyword
(
"eval_loss"
,
os
.
path
.
join
(
self
.
model_dir
,
"summaries/eval"
)))
def
test_train_and_evaluate_reset_datasets
(
self
):
test_runner
=
TestRunner
()
test_controller
=
controller
.
Controller
(
trainer
=
test_runner
,
evaluator
=
test_runner
,
global_step
=
test_runner
.
global_step
,
steps_per_loop
=
2
)
test_controller
.
train_and_evaluate
(
train_steps
=
10
,
eval_steps
=
2
,
eval_interval
=
6
)
train_dataset
=
(
test_runner
.
strategy
.
experimental_distribute_datasets_from_function
(
dataset_fn
))
eval_dataset
=
(
test_runner
.
strategy
.
experimental_distribute_datasets_from_function
(
dataset_fn
))
test_runner
.
train_dataset
=
train_dataset
test_runner
.
eval_dataset
=
eval_dataset
test_controller
.
train_and_evaluate
(
train_steps
=
10
,
eval_steps
=
2
,
eval_interval
=
6
)
if
__name__
==
"__main__"
:
tf
.
test
.
main
()
o
fficial/staging/training
/runn
abl
e.py
→
o
rbit
/runne
r
.py
View file @
31ca3b97
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -14,19 +15,12 @@
# ==============================================================================
"""An abstraction that users can easily handle their custom training loops."""
from
__future__
import
absolute_import
from
__future__
import
division
# from __future__ import google_type_annotations
from
__future__
import
print_function
import
abc
import
six
import
tensorflow.compat.v2
as
tf
from
typing
import
Dict
,
Optional
,
Text
import
tensorflow
as
tf
@
six
.
add_metaclass
(
abc
.
ABCMeta
)
class
AbstractTrainable
(
tf
.
Module
):
class
AbstractTrainer
(
tf
.
Module
,
metaclass
=
abc
.
ABCMeta
):
"""An abstract class defining the APIs required for training."""
@
abc
.
abstractmethod
...
...
@@ -50,14 +44,13 @@ class AbstractTrainable(tf.Module):
one update to model parameters, e.g. if training a GAN).
Returns:
The function may return a dictionary of `Tensors`
, which will be
written to logs and as TensorBoard summaries.
The function may return a dictionary of `Tensors`
or numpy arrays, which
will be
written to logs and as TensorBoard summaries.
"""
pass
@
six
.
add_metaclass
(
abc
.
ABCMeta
)
class
AbstractEvaluable
(
tf
.
Module
):
class
AbstractEvaluator
(
tf
.
Module
,
metaclass
=
abc
.
ABCMeta
):
"""An abstract class defining the APIs required for evaluation."""
@
abc
.
abstractmethod
...
...
@@ -73,7 +66,7 @@ class AbstractEvaluable(tf.Module):
is `None`.
Returns:
The function may return a dictionary of `Tensors`
, which will be
written to logs and as TensorBoard summaries.
The function may return a dictionary of `Tensors`
or numpy arrays, which
will be
written to logs and as TensorBoard summaries.
"""
pass
o
fficial/staging/training
/standard_runn
abl
e.py
→
o
rbit
/standard_runne
r
.py
View file @
31ca3b97
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -14,67 +15,79 @@
# ==============================================================================
"""An abstraction that users can easily handle their custom training loops."""
from
__future__
import
absolute_import
from
__future__
import
division
# from __future__ import google_type_annotations
from
__future__
import
print_function
import
abc
import
six
import
tensorflow.compat.v2
as
tf
from
typing
import
Dict
,
Optional
,
Text
from
typing
import
Any
,
Dict
,
Optional
,
Text
from
orbit
import
runner
from
orbit
import
utils
import
tensorflow
as
tf
from
official.staging.training
import
runnable
from
official.staging.training
import
utils
class
StandardTrainer
(
runner
.
AbstractTrainer
,
metaclass
=
abc
.
ABCMeta
):
"""Implements the standard functionality of AbstractTrainer APIs."""
@
six
.
add_metaclass
(
abc
.
ABCMeta
)
class
StandardTrainable
(
runnable
.
AbstractTrainable
):
"""Implements the standard functionality of AbstractTrainable APIs."""
def
__init__
(
self
,
train_dataset
,
use_tf_while_loop
=
True
,
use_tf_function
=
True
,
use_tpu_summary_optimization
=
False
):
"""Construct a `StandardTrainer` object.
def
__init__
(
self
,
use_tf_while_loop
=
True
,
use_tf_function
=
True
):
Args:
train_dataset: A tf.nest-compatible structure of tf.data.Dataset or
DistributedDataset.
use_tf_while_loop: A boolean indicates whether to wrap the train step with
a `tf.while_loop`.
use_tf_function: A boolean indicates whether a `tf.function` will be used.
If False, training will run on pure eager mode.
use_tpu_summary_optimization: A boolean indicates whether to enable the
performance optimization for summaries in TPUs. In TPUs, writing
summaries with outside compilation inside train step is slow. If True,
it creates two `tf.function` with two XLA programs: one with summaries
and one without, and run the program with summaries (slow one) only if
necessary.
"""
if
use_tf_while_loop
and
not
use_tf_function
:
raise
ValueError
(
"`use_tf_while_loop=True` and `use_tf_function=False` "
"is not supported"
)
self
.
use_tf_while_loop
=
use_tf_while_loop
self
.
use_tf_function
=
use_tf_function
self
.
train_dataset
=
None
self
.
train_iter
=
None
self
.
train_loop_fn
=
None
@
abc
.
abstractmethod
def
build_train_dataset
(
self
):
"""Builds the training datasets.
Returns:
A tf.nest-compatible structure of tf.data.Dataset or DistributedDataset.
"""
pass
if
use_tpu_summary_optimization
and
not
use_tf_while_loop
:
raise
ValueError
(
"`use_tpu_summary_optimization=True` and "
"`use_tf_while_loop=False` is not supported"
)
self
.
_use_tf_while_loop
=
use_tf_while_loop
self
.
_use_tf_function
=
use_tf_function
self
.
_train_dataset
=
train_dataset
self
.
_train_iter
=
None
self
.
_train_loop_fn
=
None
self
.
_use_tpu_summary_optimization
=
use_tpu_summary_optimization
def
train
(
self
,
num_steps
:
Optional
[
tf
.
Tensor
])
->
Optional
[
Dict
[
Text
,
tf
.
Tensor
]]:
"""See base class."""
if
self
.
train_
dataset
is
None
:
# Build train input dataset
self
.
train_
dataset
=
self
.
build_train_dataset
()
self
.
train_iter
=
tf
.
nest
.
map_structure
(
iter
,
self
.
train_dataset
)
self
.
train_
loop_begin
()
if
self
.
_
train_
iter
is
None
:
self
.
_
train_iter
=
tf
.
nest
.
map_structure
(
iter
,
self
.
train_dataset
)
if
self
.
train_loop_fn
is
None
:
if
self
.
_
train_loop_fn
is
None
:
train_fn
=
self
.
train_step
if
self
.
use_tf_while_loop
:
self
.
train_loop_fn
=
utils
.
create_tf_while_loop_fn
(
train_fn
)
if
self
.
_use_tf_while_loop
:
self
.
_train_loop_fn
=
utils
.
create_tf_while_loop_fn
(
train_fn
)
if
self
.
_use_tpu_summary_optimization
:
self
.
_train_loop_fn
=
utils
.
train_function_with_summaries
(
self
.
_train_loop_fn
)
else
:
if
self
.
use_tf_function
:
self
.
_train_loop_fn
=
tf
.
function
(
self
.
_train_loop_fn
)
else
:
if
self
.
_use_tf_function
:
train_fn
=
tf
.
function
(
train_fn
)
self
.
train_loop_fn
=
utils
.
create_loop_fn
(
train_fn
)
self
.
_
train_loop_fn
=
utils
.
create_loop_fn
(
train_fn
)
self
.
train_loop_begin
()
self
.
train_loop_fn
(
self
.
train_iter
,
num_steps
)
self
.
_train_loop_fn
(
self
.
_train_iter
,
num_steps
)
return
self
.
train_loop_end
()
def
train_loop_begin
(
self
):
"""Called once at the beginning of the training loop.
This method is called before dataset iterators creation.
This is a good place to reset metrics that accumulate values over multiple
steps of training.
"""
...
...
@@ -107,54 +120,74 @@ class StandardTrainable(runnable.AbstractTrainable):
"""
pass
@
property
def
train_dataset
(
self
):
"""Returns the train_dataset instance."""
return
self
.
_train_dataset
@
six
.
add_metaclass
(
abc
.
ABCMeta
)
class
StandardEvaluable
(
runnable
.
AbstractEvaluable
):
"""
Implements the standard functionality of AbstractEvaluable APIs."""
@
train_dataset
.
setter
def
train_dataset
(
self
,
train_dataset
):
"""
Set a new train dataset and replace with the existing one.
def
__init__
(
self
,
use_tf_function
=
True
):
self
.
eval_use_tf_function
=
use_tf_function
self
.
eval_dataset
=
None
self
.
eval_loop_fn
=
None
Any unfinished work in the previous dataset will be discarded.
@
abc
.
abstractmethod
def
build_eval_dataset
(
self
):
"""Builds the evaluation datasets.
Args:
train_dataset: A tf.nest-compatible structure of tf.data.Dataset or
DistributedDataset.
"""
self
.
_train_dataset
=
train_dataset
self
.
_train_iter
=
None
Returns:
A tf.nest-compatible structure of tf.data.Dataset or DistributedDataset.
class
StandardEvaluator
(
runner
.
AbstractEvaluator
,
metaclass
=
abc
.
ABCMeta
):
"""Implements the standard functionality of AbstractEvaluator APIs."""
def
__init__
(
self
,
eval_dataset
,
use_tf_function
=
True
):
"""Construct a `StandardEvaluator` object.
Args:
eval_dataset: A tf.nest-compatible structure of tf.data.Dataset or
DistributedDataset.
use_tf_function: A boolean indicates whether a `tf.function` will be used.
If False, evaluation will run on pure eager mode.
"""
pass
self
.
_eval_use_tf_function
=
use_tf_function
self
.
_eval_dataset
=
eval_dataset
self
.
_eval_loop_fn
=
None
def
evaluate
(
self
,
num_steps
:
Optional
[
tf
.
Tensor
])
->
Optional
[
Dict
[
Text
,
tf
.
Tensor
]]:
"""See base class."""
if
self
.
eval_dataset
is
None
:
# Build train input dataset
self
.
eval_dataset
=
self
.
build_eval_dataset
()
outputs
=
self
.
eval_begin
()
# pylint: disable=assignment-from-no-return
if
self
.
eval_loop_fn
is
None
:
eval_iter
=
tf
.
nest
.
map_structure
(
iter
,
self
.
_eval_dataset
)
if
self
.
_eval_loop_fn
is
None
:
eval_fn
=
self
.
eval_step
if
self
.
eval_use_tf_function
:
if
self
.
_
eval_use_tf_function
:
eval_fn
=
tf
.
function
(
eval_fn
)
self
.
eval_loop_fn
=
utils
.
create_loop_fn
(
eval_fn
)
self
.
_
eval_loop_fn
=
utils
.
create_loop_fn
(
eval_fn
)
eval_iter
=
tf
.
nest
.
map_structure
(
iter
,
self
.
eval_dataset
)
self
.
eval_begin
()
self
.
eval_loop_fn
(
eval_iter
,
num_steps
)
outputs
=
self
.
_eval_loop_fn
(
eval_iter
,
num_steps
,
state
=
outputs
,
reduce_fn
=
self
.
eval_reduce
)
if
outputs
is
None
:
return
self
.
eval_end
()
else
:
return
self
.
eval_end
(
outputs
)
def
eval_begin
(
self
):
def
eval_begin
(
self
)
->
Any
:
"""Called once at the beginning of the evaluation.
This method is called before dataset iterators creation.
This is a good place to reset metrics that accumulate values over the entire
evaluation.
Returns:
An output which is passed as `state` argument into `eval_reduce` function.
"""
pass
@
abc
.
abstractmethod
def
eval_step
(
self
,
iterator
):
def
eval_step
(
self
,
iterator
)
->
Any
:
"""Implements one step of evaluation.
What a "step" consists of is up to the implementer. If using distribution
...
...
@@ -165,17 +198,57 @@ class StandardEvaluable(runnable.AbstractEvaluable):
Args:
iterator: A tf.nest-compatible structure of tf.data Iterator or
DistributedIterator.
Returns:
An output which is passed as `step_outputs` argument into `eval_reduce`
function.
"""
pass
def
eval_end
(
self
)
->
Optional
[
Dict
[
Text
,
tf
.
Tensor
]]:
def
eval_end
(
self
,
*
args
)
->
Optional
[
Dict
[
Text
,
tf
.
Tensor
]]:
"""Called at the end of the evaluation.
This is a good place to get metric results. The value returned from this
function will be returned as-is from the evaluate() method.
Args:
*args: the outputs from `eval_reduce` for the last eval step.
Returns:
The function may return a dictionary of `Tensors`, which will be
written to logs and as TensorBoard summaries.
"""
pass
def
eval_reduce
(
self
,
state
=
None
,
step_outputs
=
None
)
->
Any
:
"""A function to do the reduction on the evaluation outputs per step.
This is useful for passing states throughout evaluation. E.g. it can be used
to maintain the output losses from all the evaluation steps, and compute the
mean loss in `eval_end` function.
Args:
state: A maintained state throughout the evaluation.
step_outputs: Outputs from the current evaluation step.
Returns:
An output which is passed as `state` argument into `eval_reduce` function
for the next step. After evaluation is finished, the output from last step
will be passed into `eval_end` function.
"""
pass
@
property
def
eval_dataset
(
self
):
"""Returns the train_datase instance."""
return
self
.
_eval_dataset
@
eval_dataset
.
setter
def
eval_dataset
(
self
,
eval_dataset
):
"""Set a new eval dataset and replace with the existing one.
Args:
eval_dataset: A tf.nest-compatible structure of tf.data.Dataset or
DistributedDataset.
"""
self
.
_eval_dataset
=
eval_dataset
orbit/standard_runner_test.py
0 → 100644
View file @
31ca3b97
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for orbit.standard_runner."""
# pylint: disable=g-bad-import-order
from
orbit
import
standard_runner
import
tensorflow
as
tf
def
dataset_fn
(
input_context
=
None
):
del
input_context
def
dummy_data
(
_
):
return
tf
.
zeros
((
1
,
1
),
dtype
=
tf
.
float32
)
dataset
=
tf
.
data
.
Dataset
.
range
(
1
)
dataset
=
dataset
.
repeat
()
dataset
=
dataset
.
map
(
dummy_data
,
num_parallel_calls
=
tf
.
data
.
experimental
.
AUTOTUNE
)
return
dataset
class
TestRunner
(
standard_runner
.
StandardTrainer
,
standard_runner
.
StandardEvaluator
):
"""Implements the training and evaluation APIs for tests."""
def
__init__
(
self
):
self
.
strategy
=
tf
.
distribute
.
get_strategy
()
self
.
global_step
=
tf
.
Variable
(
0
,
trainable
=
False
,
dtype
=
tf
.
int64
,
name
=
'global_step'
,
aggregation
=
tf
.
VariableAggregation
.
ONLY_FIRST_REPLICA
)
standard_runner
.
StandardTrainer
.
__init__
(
self
,
train_dataset
=
None
)
standard_runner
.
StandardEvaluator
.
__init__
(
self
,
eval_dataset
=
None
)
def
train_loop_begin
(
self
):
self
.
train_dataset
=
(
self
.
strategy
.
experimental_distribute_datasets_from_function
(
dataset_fn
)
)
def
train_step
(
self
,
iterator
):
def
_replicated_step
(
_
):
self
.
global_step
.
assign_add
(
1
)
self
.
strategy
.
run
(
_replicated_step
,
args
=
(
next
(
iterator
),))
def
train_loop_end
(
self
):
return
self
.
global_step
.
numpy
()
def
eval_begin
(
self
):
self
.
eval_dataset
=
self
.
strategy
.
experimental_distribute_datasets_from_function
(
dataset_fn
)
def
eval_step
(
self
,
iterator
):
def
_replicated_step
(
_
):
self
.
global_step
.
assign_add
(
1
)
self
.
strategy
.
run
(
_replicated_step
,
args
=
(
next
(
iterator
),))
def
eval_end
(
self
):
return
self
.
global_step
.
numpy
()
class
StandardRunnerTest
(
tf
.
test
.
TestCase
):
def
test_train
(
self
):
test_runner
=
TestRunner
()
self
.
assertEqual
(
test_runner
.
train
(
tf
.
convert_to_tensor
(
10
,
dtype
=
tf
.
int32
)),
10
)
def
test_eval
(
self
):
test_runner
=
TestRunner
()
self
.
assertEqual
(
test_runner
.
evaluate
(
tf
.
convert_to_tensor
(
10
,
dtype
=
tf
.
int32
)),
10
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
o
fficial/staging/training
/utils.py
→
o
rbit
/utils.py
View file @
31ca3b97
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -14,16 +15,13 @@
# ==============================================================================
"""Some layered modules/functions to help users writing custom training loop."""
from
__future__
import
absolute_import
from
__future__
import
division
# from __future__ import google_type_annotations
from
__future__
import
print_function
import
abc
import
contextlib
import
functools
import
inspect
import
six
import
tensorflow.compat.v2
as
tf
import
numpy
as
np
import
tensorflow
as
tf
def
create_loop_fn
(
step_fn
):
...
...
@@ -79,7 +77,6 @@ def create_tf_while_loop_fn(step_fn):
A callable defined as the `loop_fn` defination below.
"""
@
tf
.
function
def
loop_fn
(
iterator
,
num_steps
):
"""A loop function with multiple steps.
...
...
@@ -130,10 +127,7 @@ def make_distributed_dataset(strategy, dataset_or_fn, *args, **kwargs):
# names, pass `ctx` as the value of `input_context` when calling
# `dataset_or_fn`. Otherwise `ctx` will not be used when calling
# `dataset_or_fn`.
if
six
.
PY3
:
argspec
=
inspect
.
getfullargspec
(
dataset_or_fn
)
else
:
argspec
=
inspect
.
getargspec
(
dataset_or_fn
)
args_names
=
argspec
.
args
if
"input_context"
in
args_names
:
...
...
@@ -144,96 +138,62 @@ def make_distributed_dataset(strategy, dataset_or_fn, *args, **kwargs):
return
strategy
.
experimental_distribute_datasets_from_function
(
dataset_fn
)
class
SummaryManager
(
object
)
:
class
SummaryManager
:
"""A class manages writing summaries."""
def
__init__
(
self
,
summary_writer
,
summary_fn
,
global_step
=
None
,
summary_interval
=
None
):
def
__init__
(
self
,
summary_dir
,
summary_fn
,
global_step
=
None
):
"""Construct a summary manager object.
Args:
summary_writer: A `tf.summary.SummaryWriter` instance for writing
summaries.
summary_dir: the directory to write summaries.
summary_fn: A callable defined as `def summary_fn(name, tensor,
step=None)`, which describes the summary operation.
global_step: A `tf.Variable` instance for checking the current global step
value, in case users want to save summaries every N steps.
summary_interval: An integer, indicates the minimum step interval between
two summaries.
global_step: A `tf.Variable` instance for the global step.
"""
if
summary_writer
is
not
None
:
self
.
_summary_writer
=
summary_writer
self
.
_enabled
=
True
else
:
self
.
_summary_writer
=
tf
.
summary
.
create_noop_writer
()
self
.
_enabled
=
False
self
.
_enabled
=
(
summary_dir
is
not
None
)
self
.
_summary_dir
=
summary_dir
self
.
_summary_fn
=
summary_fn
self
.
_summary_writer
=
None
if
global_step
is
None
:
self
.
_global_step
=
tf
.
summary
.
experimental
.
get_step
()
else
:
self
.
_global_step
=
global_step
if
summary_interval
is
not
None
:
if
self
.
_global_step
is
None
:
raise
ValueError
(
"`summary_interval` is not None, but no `global_step` "
"can be obtained "
)
self
.
_last_summary_step
=
self
.
_global_step
.
numpy
()
self
.
_summary_interval
=
summary_interval
@
property
def
summary_interval
(
self
):
return
self
.
_summary_interval
@
property
def
summary_writer
(
self
):
"""Returns the underlying summary writer."""
if
self
.
_summary_writer
is
not
None
:
return
self
.
_summary_writer
if
self
.
_enabled
:
self
.
_summary_writer
=
tf
.
summary
.
create_file_writer
(
self
.
_summary_dir
)
else
:
self
.
_summary_writer
=
tf
.
summary
.
create_noop_writer
()
return
self
.
_summary_writer
def
flush
(
self
):
"""Flush the underlying summary writer."""
if
self
.
_enabled
:
tf
.
summary
.
flush
(
self
.
_
summary_writer
)
tf
.
summary
.
flush
(
self
.
summary_writer
)
def
write_summaries
(
self
,
items
,
always_write
=
True
):
def
write_summaries
(
self
,
items
):
"""Write a bulk of summaries.
Args:
items: a dictionary of `Tensors` for writing summaries.
always_write: An optional boolean. If `True`, the manager will always
write summaries unless the summaries have been written for the same
step. Otherwise the manager will only write the summaries if the
interval between summaries are larger than `summary_interval`.
Returns:
A boolean indicates whether the summaries are written or not.
"""
# TODO(rxsang): Support writing summaries with nested structure, so users
# can split the summaries into different directories for nicer visualization
# in Tensorboard, like train and eval metrics.
if
not
self
.
_enabled
:
return
False
if
self
.
_summary_interval
is
not
None
:
current_step
=
self
.
_global_step
.
numpy
()
if
current_step
==
self
.
_last_summary_step
:
return
False
if
not
always_write
and
current_step
<
(
self
.
_last_summary_step
+
self
.
_summary_interval
):
return
False
self
.
_last_summary_step
=
current_step
return
with
self
.
_
summary_writer
.
as_default
():
with
self
.
summary_writer
.
as_default
():
for
name
,
tensor
in
items
.
items
():
self
.
_summary_fn
(
name
,
tensor
,
step
=
self
.
_global_step
)
return
True
@
six
.
add_metaclass
(
abc
.
ABCMeta
)
class
Trigger
(
object
):
class
Trigger
(
metaclass
=
abc
.
ABCMeta
):
"""An abstract class representing a "trigger" for some event."""
@
abc
.
abstractmethod
...
...
@@ -294,7 +254,7 @@ class IntervalTrigger(Trigger):
self
.
_last_trigger_value
=
0
class
EpochHelper
(
object
)
:
class
EpochHelper
:
"""A Helper class to handle epochs in Customized Training Loop."""
def
__init__
(
self
,
epoch_steps
,
global_step
):
...
...
@@ -340,3 +300,86 @@ class EpochHelper(object):
@
property
def
current_epoch
(
self
):
return
self
.
_current_epoch
@contextlib.contextmanager
def _soft_device_placement():
  """Temporarily enables soft device placement within the managed scope.

  This allows ops without a device kernel (e.g. summary ops on TPU) to fall
  back to a compatible device such as the CPU. The previous setting is
  restored on exit, even if the body raises.
  """
  previous = tf.config.get_soft_device_placement()
  try:
    tf.config.set_soft_device_placement(True)
    yield
  finally:
    tf.config.set_soft_device_placement(previous)
def train_function_with_summaries(*args, **kwargs):
  """Decorator supporting TPU summaries via two `tf.function` variants.

  This permits interleaving summaries inside TPU-compatible code without any
  performance impact on steps that do not write summaries.

  Usage mirrors `tf.function`, and any `tf.function` arguments are forwarded
  if supplied:

      @trainer.train_function_with_summaries
      def train(self, num_steps):
        ...

  The decorated function is assumed to be a loop method accepting `self` and
  `num_steps`, as called from the `Controller`'s outer train loop. The
  implementation assumes `summary_frequency` is divisible by `steps_per_loop`.

  Two `tf.function` versions of the wrapped method are traced: one inside a
  summary writer scope with soft device placement enabled (used on steps that
  write summaries), and one with summary recording disabled and normal device
  placement (used on all other steps).

  Args:
    *args: Positional arguments forwarded to `tf.function`.
    **kwargs: Keyword arguments forwarded to `tf.function`.

  Returns:
    The decorated callable if the first argument is callable; otherwise a
    decorator.
  """

  def decorator(train_fn):
    # TODO(dhr): Validate the signature of train_fn?
    fn_with_summaries = tf.function(train_fn, *args, **kwargs)
    fn_without_summaries = tf.function(train_fn, *args, **kwargs)

    @functools.wraps(train_fn)
    def wrapper(self, num_steps):
      if tf.summary.should_record_summaries():
        # Run a single step with summaries enabled, then fall through to the
        # summary-free function for the remainder of the loop.
        with _soft_device_placement():
          output = fn_with_summaries(self, tf.constant(1))
          num_steps -= 1
      if num_steps >= 1:
        with tf.summary.record_if(False):
          output = fn_without_summaries(self, num_steps)
      return output

    return wrapper

  # Support both bare use (@train_function_with_summaries) and parameterized
  # use (@train_function_with_summaries(...)).
  if args and callable(args[0]):
    wrapped, args = args[0], args[1:]
    return decorator(wrapped)
  return decorator
def get_value(x) -> np.ndarray:
  """Returns the concrete value of a variable or tensor.

  Args:
    x: The input; may be a `tf.Tensor`/`tf.Variable` or any other object.

  Returns:
    A NumPy array when `x` is a tensor, otherwise `x` unchanged.
  """
  return x.numpy() if tf.is_tensor(x) else x
research/attention_ocr/README.md
View file @
31ca3b97
#
# Attention-based Extraction of Structured Information from Street View Imagery
# Attention-based Extraction of Structured Information from Street View Imagery
[

](https://paperswithcode.com/sota/optical-character-recognition-on-fsns-test?p=attention-based-extraction-of-structured)
[

](https://arxiv.org/abs/1704.03549)
...
...
@@ -7,14 +7,20 @@
*A TensorFlow model for real-world image text extraction problems.*
This folder contains the code needed to train a new Attention OCR model on the
[
FSNS dataset
][
FSNS
]
dataset to transcribe street names in France. You can
also use it to train it on your own data.
[
FSNS dataset
][
FSNS
]
to transcribe street names in France. You can also train the model on your own data.
More details can be found in our paper:
[
"Attention-based Extraction of Structured Information from Street View
Imagery"
](
https://arxiv.org/abs/1704.03549
)
## Description
*
The paper presents a model based on ConvNets, RNNs, and a novel attention mechanism.
It achieves **84.2%** accuracy on FSNS, beating the previous benchmark
(**72.46%**), and also studies the speed/accuracy tradeoff that results from
using CNN feature extractors of different depths.
## Contacts
Authors
...
...
@@ -22,7 +28,18 @@ Authors
*
Zbigniew Wojna (zbigniewwojna@gmail.com)
*
Alexander Gorban (gorban@google.com)
Maintainer: Xavier Gibert
[
@xavigibert
](
https://github.com/xavigibert
)
Maintainer
*
Xavier Gibert (
[
@xavigibert
](
https://github.com/xavigibert
)
)
## Table of Contents
*
[
Requirements
](
https://github.com/tensorflow/models/blob/master/research/attention_ocr/README.md#requirements
)
*
[
Dataset
](
https://github.com/tensorflow/models/blob/master/research/attention_ocr/README.md#dataset
)
*
[
How to use this code
](
https://github.com/tensorflow/models/blob/master/research/attention_ocr/README.md#how-to-use-this-code
)
*
[
Using your own image data
](
https://github.com/tensorflow/models/blob/master/research/attention_ocr/README.md#using-your-own-image-data
)
*
[
How to use a pre-trained model
](
https://github.com/tensorflow/models/blob/master/research/attention_ocr/README.md#how-to-use-a-pre-trained-model
)
*
[
Disclaimer
](
https://github.com/tensorflow/models/blob/master/research/attention_ocr/README.md#disclaimer
)
## Requirements
...
...
@@ -49,6 +66,42 @@ cd ..
[
TF
]:
https://www.tensorflow.org/install/
[
FSNS
]:
https://github.com/tensorflow/models/tree/master/research/street
## Dataset
The French Street Name Signs (FSNS) dataset is split into subsets,
each of which is composed of multiple files. Note that these datasets
are very large. The approximate sizes are:
*
Train: 512 files of 300MB each.
*
Validation: 64 files of 40MB each.
*
Test: 64 files of 50MB each.
*
The datasets download includes a directory
`testdata`
that contains
some small datasets that are big enough to test that models can
actually learn something.
*
Total: around 158GB
The download paths are in the following list:
```
https://download.tensorflow.org/data/fsns-20160927/charset_size=134.txt
https://download.tensorflow.org/data/fsns-20160927/test/test-00000-of-00064
...
https://download.tensorflow.org/data/fsns-20160927/test/test-00063-of-00064
https://download.tensorflow.org/data/fsns-20160927/testdata/arial-32-00000-of-00001
https://download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001
https://download.tensorflow.org/data/fsns-20160927/testdata/mnist-sample-00000-of-00001
https://download.tensorflow.org/data/fsns-20160927/testdata/numbers-16-00000-of-00001
https://download.tensorflow.org/data/fsns-20160927/train/train-00000-of-00512
...
https://download.tensorflow.org/data/fsns-20160927/train/train-00511-of-00512
https://download.tensorflow.org/data/fsns-20160927/validation/validation-00000-of-00064
...
https://download.tensorflow.org/data/fsns-20160927/validation/validation-00063-of-00064
```
All URLs are stored in the
[
research/street
](
https://github.com/tensorflow/models/tree/master/research/street
)
repository in the text file
`python/fsns_urls.txt`
.
## How to use this code
To run all unit tests:
...
...
@@ -80,7 +133,7 @@ tar xf attention_ocr_2017_08_09.tar.gz
python train.py --checkpoint=model.ckpt-399731
```
##
How to use
your own image data
to train the model
##
Using
your own image data
You need to define a new dataset. There are two options:
...
...
@@ -166,6 +219,14 @@ implement one in Python or C++.
The recommended way is to use the
[
Serving infrastructure
][
serving
]
.
To export to SavedModel format:
```
python model_export.py \
--checkpoint=model.ckpt-399731 \
--export_dir=/tmp/attention_ocr_export
```
Alternatively you can:
1.
define a placeholder for images (or use directly an numpy array)
2.
[
create a graph
](
https://github.com/tensorflow/models/blob/master/research/attention_ocr/python/eval.py#L60
)
...
...
@@ -188,7 +249,7 @@ other than a one time experiment please use the [TensorFlow Serving][serving].
[
1
]:
https://github.com/tensorflow/tensorflow/blob/aaf7adc/tensorflow/contrib/rnn/python/tools/checkpoint_convert.py
[
2
]:
https://www.tensorflow.org/api_docs/python/tf/contrib/framework/assign_from_checkpoint_fn
[
serving
]:
https://tensorflow.
github.io
/serving/serving_basic
[
serving
]:
https://
www.
tensorflow.
org/tfx
/serving/serving_basic
## Disclaimer
...
...
research/attention_ocr/python/common_flags.py
View file @
31ca3b97
...
...
@@ -14,10 +14,10 @@
# ==============================================================================
"""Define flags are common for both train.py and eval.py scripts."""
import
logging
import
sys
from
tensorflow.python.platform
import
flags
import
logging
import
datasets
import
model
...
...
@@ -35,9 +35,17 @@ logging.basicConfig(
datefmt
=
'%Y-%m-%d %H:%M:%S'
)
_common_flags_defined
=
False
def
define
():
"""Define common flags."""
# yapf: disable
# common_flags.define() may be called multiple times in unit tests.
global
_common_flags_defined
if
_common_flags_defined
:
return
_common_flags_defined
=
True
flags
.
DEFINE_integer
(
'batch_size'
,
32
,
'Batch size.'
)
...
...
research/attention_ocr/python/data_provider.py
View file @
31ca3b97
...
...
@@ -56,14 +56,14 @@ def augment_image(image):
Returns:
Distorted Tensor image of the same shape.
"""
with
tf
.
variable_scope
(
'AugmentImage'
):
with
tf
.
compat
.
v1
.
variable_scope
(
'AugmentImage'
):
height
=
image
.
get_shape
().
dims
[
0
].
value
width
=
image
.
get_shape
().
dims
[
1
].
value
# Random crop cut from the street sign image, resized to the same size.
# Assures that the crop is covers at least 0.8 area of the input image.
bbox_begin
,
bbox_size
,
_
=
tf
.
image
.
sample_distorted_bounding_box
(
tf
.
shape
(
image
),
image_size
=
tf
.
shape
(
input
=
image
),
bounding_boxes
=
tf
.
zeros
([
0
,
0
,
4
]),
min_object_covered
=
0.8
,
aspect_ratio_range
=
[
0.8
,
1.2
],
...
...
@@ -74,7 +74,7 @@ def augment_image(image):
# Randomly chooses one of the 4 interpolation methods
distorted_image
=
inception_preprocessing
.
apply_with_random_selector
(
distorted_image
,
lambda
x
,
method
:
tf
.
image
.
resize
_images
(
x
,
[
height
,
width
],
method
),
lambda
x
,
method
:
tf
.
image
.
resize
(
x
,
[
height
,
width
],
method
),
num_cases
=
4
)
distorted_image
.
set_shape
([
height
,
width
,
3
])
...
...
@@ -99,9 +99,10 @@ def central_crop(image, crop_size):
Returns:
A tensor of shape [crop_height, crop_width, channels].
"""
with
tf
.
variable_scope
(
'CentralCrop'
):
with
tf
.
compat
.
v1
.
variable_scope
(
'CentralCrop'
):
target_width
,
target_height
=
crop_size
image_height
,
image_width
=
tf
.
shape
(
image
)[
0
],
tf
.
shape
(
image
)[
1
]
image_height
,
image_width
=
tf
.
shape
(
input
=
image
)[
0
],
tf
.
shape
(
input
=
image
)[
1
]
assert_op1
=
tf
.
Assert
(
tf
.
greater_equal
(
image_height
,
target_height
),
[
'image_height < target_height'
,
image_height
,
target_height
])
...
...
@@ -129,7 +130,7 @@ def preprocess_image(image, augment=False, central_crop_size=None,
A float32 tensor of shape [H x W x 3] with RGB values in the required
range.
"""
with
tf
.
variable_scope
(
'PreprocessImage'
):
with
tf
.
compat
.
v1
.
variable_scope
(
'PreprocessImage'
):
image
=
tf
.
image
.
convert_image_dtype
(
image
,
dtype
=
tf
.
float32
)
if
augment
or
central_crop_size
:
if
num_towers
==
1
:
...
...
@@ -144,9 +145,6 @@ def preprocess_image(image, augment=False, central_crop_size=None,
images
=
[
augment_image
(
img
)
for
img
in
images
]
image
=
tf
.
concat
(
images
,
1
)
image
=
tf
.
subtract
(
image
,
0.5
)
image
=
tf
.
multiply
(
image
,
2.5
)
return
image
...
...
@@ -185,7 +183,7 @@ def get_data(dataset,
image_orig
,
augment
,
central_crop_size
,
num_towers
=
dataset
.
num_of_views
)
label_one_hot
=
slim
.
one_hot_encoding
(
label
,
dataset
.
num_char_classes
)
images
,
images_orig
,
labels
,
labels_one_hot
=
(
tf
.
train
.
shuffle_batch
(
images
,
images_orig
,
labels
,
labels_one_hot
=
(
tf
.
compat
.
v1
.
train
.
shuffle_batch
(
[
image
,
image_orig
,
label
,
label_one_hot
],
batch_size
=
batch_size
,
num_threads
=
shuffle_config
.
num_batching_threads
,
...
...
research/attention_ocr/python/datasets/fsns.py
View file @
31ca3b97
...
...
@@ -72,7 +72,7 @@ def read_charset(filename, null_character=u'\u2591'):
"""
pattern
=
re
.
compile
(
r
'(\d+)\t(.+)'
)
charset
=
{}
with
tf
.
gfile
.
GFile
(
filename
)
as
f
:
with
tf
.
io
.
gfile
.
GFile
(
filename
)
as
f
:
for
i
,
line
in
enumerate
(
f
):
m
=
pattern
.
match
(
line
)
if
m
is
None
:
...
...
@@ -96,9 +96,9 @@ class _NumOfViewsHandler(slim.tfexample_decoder.ItemHandler):
self
.
_num_of_views
=
num_of_views
def
tensors_to_item
(
self
,
keys_to_tensors
):
return
tf
.
to_int64
(
return
tf
.
cast
(
self
.
_num_of_views
*
keys_to_tensors
[
self
.
_original_width_key
]
/
keys_to_tensors
[
self
.
_width_key
])
keys_to_tensors
[
self
.
_width_key
]
,
dtype
=
tf
.
int64
)
def
get_split
(
split_name
,
dataset_dir
=
None
,
config
=
None
):
...
...
@@ -133,19 +133,19 @@ def get_split(split_name, dataset_dir=None, config=None):
zero
=
tf
.
zeros
([
1
],
dtype
=
tf
.
int64
)
keys_to_features
=
{
'image/encoded'
:
tf
.
FixedLenFeature
((),
tf
.
string
,
default_value
=
''
),
tf
.
io
.
FixedLenFeature
((),
tf
.
string
,
default_value
=
''
),
'image/format'
:
tf
.
FixedLenFeature
((),
tf
.
string
,
default_value
=
'png'
),
tf
.
io
.
FixedLenFeature
((),
tf
.
string
,
default_value
=
'png'
),
'image/width'
:
tf
.
FixedLenFeature
([
1
],
tf
.
int64
,
default_value
=
zero
),
tf
.
io
.
FixedLenFeature
([
1
],
tf
.
int64
,
default_value
=
zero
),
'image/orig_width'
:
tf
.
FixedLenFeature
([
1
],
tf
.
int64
,
default_value
=
zero
),
tf
.
io
.
FixedLenFeature
([
1
],
tf
.
int64
,
default_value
=
zero
),
'image/class'
:
tf
.
FixedLenFeature
([
config
[
'max_sequence_length'
]],
tf
.
int64
),
tf
.
io
.
FixedLenFeature
([
config
[
'max_sequence_length'
]],
tf
.
int64
),
'image/unpadded_class'
:
tf
.
VarLenFeature
(
tf
.
int64
),
tf
.
io
.
VarLenFeature
(
tf
.
int64
),
'image/text'
:
tf
.
FixedLenFeature
([
1
],
tf
.
string
,
default_value
=
''
),
tf
.
io
.
FixedLenFeature
([
1
],
tf
.
string
,
default_value
=
''
),
}
items_to_handlers
=
{
'image'
:
...
...
@@ -171,12 +171,14 @@ def get_split(split_name, dataset_dir=None, config=None):
config
[
'splits'
][
split_name
][
'pattern'
])
return
slim
.
dataset
.
Dataset
(
data_sources
=
file_pattern
,
reader
=
tf
.
TFRecordReader
,
reader
=
tf
.
compat
.
v1
.
TFRecordReader
,
decoder
=
decoder
,
num_samples
=
config
[
'splits'
][
split_name
][
'size'
],
items_to_descriptions
=
config
[
'items_to_descriptions'
],
# additional parameters for convenience.
charset
=
charset
,
charset_file
=
charset_file
,
image_shape
=
config
[
'image_shape'
],
num_char_classes
=
len
(
charset
),
num_of_views
=
config
[
'num_of_views'
],
max_sequence_length
=
config
[
'max_sequence_length'
],
...
...
research/attention_ocr/python/datasets/fsns_test.py
View file @
31ca3b97
...
...
@@ -91,7 +91,7 @@ class FsnsTest(tf.test.TestCase):
image_tf
,
label_tf
=
provider
.
get
([
'image'
,
'label'
])
with
self
.
test_session
()
as
sess
:
sess
.
run
(
tf
.
global_variables_initializer
())
sess
.
run
(
tf
.
compat
.
v1
.
global_variables_initializer
())
with
slim
.
queues
.
QueueRunners
(
sess
):
image_np
,
label_np
=
sess
.
run
([
image_tf
,
label_tf
])
...
...
research/attention_ocr/python/datasets/testdata/fsns/download_data.py
View file @
31ca3b97
...
...
@@ -10,7 +10,8 @@ KEEP_NUM_RECORDS = 5
print
(
'Downloading %s ...'
%
URL
)
urllib
.
request
.
urlretrieve
(
URL
,
DST_ORIG
)
print
(
'Writing %d records from %s to %s ...'
%
(
KEEP_NUM_RECORDS
,
DST_ORIG
,
DST
))
print
(
'Writing %d records from %s to %s ...'
%
(
KEEP_NUM_RECORDS
,
DST_ORIG
,
DST
))
with
tf
.
io
.
TFRecordWriter
(
DST
)
as
writer
:
for
raw_record
in
itertools
.
islice
(
tf
.
python_io
.
tf_record_iterator
(
DST_ORIG
),
KEEP_NUM_RECORDS
):
for
raw_record
in
itertools
.
islice
(
tf
.
compat
.
v1
.
python_io
.
tf_record_iterator
(
DST_ORIG
),
KEEP_NUM_RECORDS
):
writer
.
write
(
raw_record
)
research/attention_ocr/python/demo_inference.py
View file @
31ca3b97
...
...
@@ -49,7 +49,7 @@ def load_images(file_pattern, batch_size, dataset_name):
for
i
in
range
(
batch_size
):
path
=
file_pattern
%
i
print
(
"Reading %s"
%
path
)
pil_image
=
PIL
.
Image
.
open
(
tf
.
gfile
.
GFile
(
path
,
'rb'
))
pil_image
=
PIL
.
Image
.
open
(
tf
.
io
.
gfile
.
GFile
(
path
,
'rb'
))
images_actual_data
[
i
,
...]
=
np
.
asarray
(
pil_image
)
return
images_actual_data
...
...
@@ -63,7 +63,8 @@ def create_model(batch_size, dataset_name):
num_views
=
dataset
.
num_of_views
,
null_code
=
dataset
.
null_code
,
charset
=
dataset
.
charset
)
raw_images
=
tf
.
placeholder
(
tf
.
uint8
,
shape
=
[
batch_size
,
height
,
width
,
3
])
raw_images
=
tf
.
compat
.
v1
.
placeholder
(
tf
.
uint8
,
shape
=
[
batch_size
,
height
,
width
,
3
])
images
=
tf
.
map_fn
(
data_provider
.
preprocess_image
,
raw_images
,
dtype
=
tf
.
float32
)
endpoints
=
model
.
create_base
(
images
,
labels_one_hot
=
None
)
...
...
@@ -93,4 +94,4 @@ def main(_):
if
__name__
==
'__main__'
:
tf
.
app
.
run
()
tf
.
compat
.
v1
.
app
.
run
()
research/attention_ocr/python/demo_inference_test.py
View file @
31ca3b97
...
...
@@ -14,12 +14,13 @@ class DemoInferenceTest(tf.test.TestCase):
super
(
DemoInferenceTest
,
self
).
setUp
()
for
suffix
in
[
'.meta'
,
'.index'
,
'.data-00000-of-00001'
]:
filename
=
_CHECKPOINT
+
suffix
self
.
assertTrue
(
tf
.
gfile
.
E
xists
(
filename
),
self
.
assertTrue
(
tf
.
io
.
gfile
.
e
xists
(
filename
),
msg
=
'Missing checkpoint file %s. '
'Please download and extract it from %s'
%
(
filename
,
_CHECKPOINT_URL
))
self
.
_batch_size
=
32
tf
.
flags
.
FLAGS
.
dataset_dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'datasets/testdata/fsns'
)
tf
.
flags
.
FLAGS
.
dataset_dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'datasets/testdata/fsns'
)
def
test_moving_variables_properly_loaded_from_a_checkpoint
(
self
):
batch_size
=
32
...
...
@@ -30,9 +31,9 @@ class DemoInferenceTest(tf.test.TestCase):
images_data
=
demo_inference
.
load_images
(
image_path_pattern
,
batch_size
,
dataset_name
)
tensor_name
=
'AttentionOcr_v1/conv_tower_fn/INCE/InceptionV3/Conv2d_2a_3x3/BatchNorm/moving_mean'
moving_mean_tf
=
tf
.
get_default_graph
().
get_tensor_by_name
(
moving_mean_tf
=
tf
.
compat
.
v1
.
get_default_graph
().
get_tensor_by_name
(
tensor_name
+
':0'
)
reader
=
tf
.
train
.
NewCheckpointReader
(
_CHECKPOINT
)
reader
=
tf
.
compat
.
v1
.
train
.
NewCheckpointReader
(
_CHECKPOINT
)
moving_mean_expected
=
reader
.
get_tensor
(
tensor_name
)
session_creator
=
monitored_session
.
ChiefSessionCreator
(
...
...
research/attention_ocr/python/eval.py
View file @
31ca3b97
...
...
@@ -45,8 +45,8 @@ flags.DEFINE_integer('number_of_steps', None,
def
main
(
_
):
if
not
tf
.
gfile
.
E
xists
(
FLAGS
.
eval_log_dir
):
tf
.
gfile
.
M
ake
D
irs
(
FLAGS
.
eval_log_dir
)
if
not
tf
.
io
.
gfile
.
e
xists
(
FLAGS
.
eval_log_dir
):
tf
.
io
.
gfile
.
m
ake
d
irs
(
FLAGS
.
eval_log_dir
)
dataset
=
common_flags
.
create_dataset
(
split_name
=
FLAGS
.
split_name
)
model
=
common_flags
.
create_model
(
dataset
.
num_char_classes
,
...
...
@@ -62,7 +62,7 @@ def main(_):
eval_ops
=
model
.
create_summaries
(
data
,
endpoints
,
dataset
.
charset
,
is_training
=
False
)
slim
.
get_or_create_global_step
()
session_config
=
tf
.
ConfigProto
(
device_count
=
{
"GPU"
:
0
})
session_config
=
tf
.
compat
.
v1
.
ConfigProto
(
device_count
=
{
"GPU"
:
0
})
slim
.
evaluation
.
evaluation_loop
(
master
=
FLAGS
.
master
,
checkpoint_dir
=
FLAGS
.
train_log_dir
,
...
...
research/attention_ocr/python/inception_preprocessing.py
View file @
31ca3b97
...
...
@@ -38,7 +38,7 @@ def apply_with_random_selector(x, func, num_cases):
The result of func(x, sel), where func receives the value of the
selector as a python integer, but sel is sampled dynamically.
"""
sel
=
tf
.
random
_
uniform
([],
maxval
=
num_cases
,
dtype
=
tf
.
int32
)
sel
=
tf
.
random
.
uniform
([],
maxval
=
num_cases
,
dtype
=
tf
.
int32
)
# Pass the real x only to one of the func calls.
return
control_flow_ops
.
merge
([
func
(
control_flow_ops
.
switch
(
x
,
tf
.
equal
(
sel
,
case
))[
1
],
case
)
...
...
@@ -64,7 +64,7 @@ def distort_color(image, color_ordering=0, fast_mode=True, scope=None):
Raises:
ValueError: if color_ordering not in [0, 3]
"""
with
tf
.
name_scope
(
scope
,
'distort_color'
,
[
image
]):
with
tf
.
compat
.
v1
.
name_scope
(
scope
,
'distort_color'
,
[
image
]):
if
fast_mode
:
if
color_ordering
==
0
:
image
=
tf
.
image
.
random_brightness
(
image
,
max_delta
=
32.
/
255.
)
...
...
@@ -131,7 +131,7 @@ def distorted_bounding_box_crop(image,
Returns:
A tuple, a 3-D Tensor cropped_image and the distorted bbox
"""
with
tf
.
name_scope
(
scope
,
'distorted_bounding_box_crop'
,
[
image
,
bbox
]):
with
tf
.
compat
.
v1
.
name_scope
(
scope
,
'distorted_bounding_box_crop'
,
[
image
,
bbox
]):
# Each bounding box has shape [1, num_boxes, box coords] and
# the coordinates are ordered [ymin, xmin, ymax, xmax].
...
...
@@ -143,7 +143,7 @@ def distorted_bounding_box_crop(image,
# bounding box. If no box is supplied, then we assume the bounding box is
# the entire image.
sample_distorted_bounding_box
=
tf
.
image
.
sample_distorted_bounding_box
(
tf
.
shape
(
image
),
image_size
=
tf
.
shape
(
input
=
image
),
bounding_boxes
=
bbox
,
min_object_covered
=
min_object_covered
,
aspect_ratio_range
=
aspect_ratio_range
,
...
...
@@ -188,7 +188,7 @@ def preprocess_for_train(image,
Returns:
3-D float Tensor of distorted image used for training with range [-1, 1].
"""
with
tf
.
name_scope
(
scope
,
'distort_image'
,
[
image
,
height
,
width
,
bbox
]):
with
tf
.
compat
.
v1
.
name_scope
(
scope
,
'distort_image'
,
[
image
,
height
,
width
,
bbox
]):
if
bbox
is
None
:
bbox
=
tf
.
constant
(
[
0.0
,
0.0
,
1.0
,
1.0
],
dtype
=
tf
.
float32
,
shape
=
[
1
,
1
,
4
])
...
...
@@ -198,7 +198,7 @@ def preprocess_for_train(image,
# the coordinates are ordered [ymin, xmin, ymax, xmax].
image_with_box
=
tf
.
image
.
draw_bounding_boxes
(
tf
.
expand_dims
(
image
,
0
),
bbox
)
tf
.
summary
.
image
(
'image_with_bounding_boxes'
,
image_with_box
)
tf
.
compat
.
v1
.
summary
.
image
(
'image_with_bounding_boxes'
,
image_with_box
)
distorted_image
,
distorted_bbox
=
distorted_bounding_box_crop
(
image
,
bbox
)
# Restore the shape since the dynamic slice based upon the bbox_size loses
...
...
@@ -206,7 +206,7 @@ def preprocess_for_train(image,
distorted_image
.
set_shape
([
None
,
None
,
3
])
image_with_distorted_box
=
tf
.
image
.
draw_bounding_boxes
(
tf
.
expand_dims
(
image
,
0
),
distorted_bbox
)
tf
.
summary
.
image
(
'images_with_distorted_bounding_box'
,
tf
.
compat
.
v1
.
summary
.
image
(
'images_with_distorted_bounding_box'
,
image_with_distorted_box
)
# This resizing operation may distort the images because the aspect
...
...
@@ -218,10 +218,10 @@ def preprocess_for_train(image,
num_resize_cases
=
1
if
fast_mode
else
4
distorted_image
=
apply_with_random_selector
(
distorted_image
,
lambda
x
,
method
:
tf
.
image
.
resize
_images
(
x
,
[
height
,
width
],
method
=
method
),
lambda
x
,
method
:
tf
.
image
.
resize
(
x
,
[
height
,
width
],
method
=
method
),
num_cases
=
num_resize_cases
)
tf
.
summary
.
image
(
'cropped_resized_image'
,
tf
.
compat
.
v1
.
summary
.
image
(
'cropped_resized_image'
,
tf
.
expand_dims
(
distorted_image
,
0
))
# Randomly flip the image horizontally.
...
...
@@ -233,7 +233,7 @@ def preprocess_for_train(image,
lambda
x
,
ordering
:
distort_color
(
x
,
ordering
,
fast_mode
),
num_cases
=
4
)
tf
.
summary
.
image
(
'final_distorted_image'
,
tf
.
compat
.
v1
.
summary
.
image
(
'final_distorted_image'
,
tf
.
expand_dims
(
distorted_image
,
0
))
distorted_image
=
tf
.
subtract
(
distorted_image
,
0.5
)
distorted_image
=
tf
.
multiply
(
distorted_image
,
2.0
)
...
...
@@ -265,7 +265,7 @@ def preprocess_for_eval(image,
Returns:
3-D float Tensor of prepared image.
"""
with
tf
.
name_scope
(
scope
,
'eval_image'
,
[
image
,
height
,
width
]):
with
tf
.
compat
.
v1
.
name_scope
(
scope
,
'eval_image'
,
[
image
,
height
,
width
]):
if
image
.
dtype
!=
tf
.
float32
:
image
=
tf
.
image
.
convert_image_dtype
(
image
,
dtype
=
tf
.
float32
)
# Crop the central region of the image with an area containing 87.5% of
...
...
@@ -276,8 +276,8 @@ def preprocess_for_eval(image,
if
height
and
width
:
# Resize the image to the specified height and width.
image
=
tf
.
expand_dims
(
image
,
0
)
image
=
tf
.
image
.
resize
_bilinear
(
image
,
[
height
,
width
],
align_corners
=
False
)
image
=
tf
.
image
.
resize
(
image
,
[
height
,
width
],
method
=
tf
.
image
.
ResizeMethod
.
BILINEAR
)
image
=
tf
.
squeeze
(
image
,
[
0
])
image
=
tf
.
subtract
(
image
,
0.5
)
image
=
tf
.
multiply
(
image
,
2.0
)
...
...
research/attention_ocr/python/metrics.py
View file @
31ca3b97
...
...
@@ -34,20 +34,21 @@ def char_accuracy(predictions, targets, rej_char, streaming=False):
a update_ops for execution and value tensor whose value on evaluation
returns the total character accuracy.
"""
with
tf
.
variable_scope
(
'CharAccuracy'
):
with
tf
.
compat
.
v1
.
variable_scope
(
'CharAccuracy'
):
predictions
.
get_shape
().
assert_is_compatible_with
(
targets
.
get_shape
())
targets
=
tf
.
to_int32
(
targets
)
targets
=
tf
.
cast
(
targets
,
dtype
=
tf
.
int32
)
const_rej_char
=
tf
.
constant
(
rej_char
,
shape
=
targets
.
get_shape
())
weights
=
tf
.
to_float
(
tf
.
not_equal
(
targets
,
const_rej_char
))
correct_chars
=
tf
.
to_float
(
tf
.
equal
(
predictions
,
targets
))
accuracy_per_example
=
tf
.
div
(
tf
.
reduce_sum
(
tf
.
multiply
(
correct_chars
,
weights
),
1
),
tf
.
reduce_sum
(
weights
,
1
))
weights
=
tf
.
cast
(
tf
.
not_equal
(
targets
,
const_rej_char
),
dtype
=
tf
.
float32
)
correct_chars
=
tf
.
cast
(
tf
.
equal
(
predictions
,
targets
),
dtype
=
tf
.
float32
)
accuracy_per_example
=
tf
.
compat
.
v1
.
div
(
tf
.
reduce_sum
(
input_tensor
=
tf
.
multiply
(
correct_chars
,
weights
),
axis
=
1
),
tf
.
reduce_sum
(
input_tensor
=
weights
,
axis
=
1
))
if
streaming
:
return
tf
.
contrib
.
metrics
.
streaming_mean
(
accuracy_per_example
)
else
:
return
tf
.
reduce_mean
(
accuracy_per_example
)
return
tf
.
reduce_mean
(
input_tensor
=
accuracy_per_example
)
def sequence_accuracy(predictions, targets, rej_char, streaming=False):
  """Computes whole-sequence accuracy: a sequence scores 1.0 only if every
  character matches.

  Padding positions (where the target equals `rej_char`) are forced to
  match by substituting `rej_char` into the predictions before comparison,
  so padding never causes a mismatch.

  Args:
    predictions: Integer tensor of predicted character ids of shape
      [batch, seq_length] (seq_length must be statically known — see
      `target_length` below).
    targets: Integer tensor of ground-truth character ids, same shape.
    rej_char: Character id used for padding/rejection.
    streaming: If True, return a streaming (running mean) metric: an
      update op for execution and a value tensor whose value on evaluation
      returns the total sequence accuracy.

  Returns:
    If `streaming` is True, the (value, update_op) pair produced by
    tf.contrib.metrics.streaming_mean; otherwise a scalar tensor holding
    the mean per-example sequence accuracy.
  """
  with tf.compat.v1.variable_scope('SequenceAccuracy'):
    predictions.get_shape().assert_is_compatible_with(targets.get_shape())

    targets = tf.cast(targets, dtype=tf.int32)
    const_rej_char = tf.constant(
        rej_char, shape=targets.get_shape(), dtype=tf.int32)
    include_mask = tf.not_equal(targets, const_rej_char)
    # Replace predictions at padding positions with rej_char so those
    # positions always count as correct.
    include_predictions = tf.cast(
        tf.compat.v1.where(include_mask, predictions,
                           tf.zeros_like(predictions) + rej_char),
        dtype=tf.int32)
    correct_chars = tf.cast(
        tf.equal(include_predictions, targets), dtype=tf.float32)
    correct_chars_counts = tf.cast(
        tf.reduce_sum(input_tensor=correct_chars, axis=[1]), dtype=tf.int32)
    # Requires a statically known sequence length (dimension 1).
    target_length = targets.get_shape().dims[1].value
    target_chars_counts = tf.constant(
        target_length, shape=correct_chars_counts.get_shape())
    # A sequence is correct only when all of its characters are correct.
    accuracy_per_example = tf.cast(
        tf.equal(correct_chars_counts, target_chars_counts), dtype=tf.float32)
    if streaming:
      return tf.contrib.metrics.streaming_mean(accuracy_per_example)
    return tf.reduce_mean(input_tensor=accuracy_per_example)
research/attention_ocr/python/metrics_test.py
View file @
31ca3b97
...
...
@@ -38,8 +38,8 @@ class AccuracyTest(tf.test.TestCase):
A session object that should be used as a context manager.
"""
with
self
.
cached_session
()
as
sess
:
sess
.
run
(
tf
.
global_variables_initializer
())
sess
.
run
(
tf
.
local_variables_initializer
())
sess
.
run
(
tf
.
compat
.
v1
.
global_variables_initializer
())
sess
.
run
(
tf
.
compat
.
v1
.
local_variables_initializer
())
yield
sess
def
_fake_labels
(
self
):
...
...
@@ -55,7 +55,7 @@ class AccuracyTest(tf.test.TestCase):
return
incorrect
def
test_sequence_accuracy_identical_samples
(
self
):
labels_tf
=
tf
.
convert_to_tensor
(
self
.
_fake_labels
())
labels_tf
=
tf
.
convert_to_tensor
(
value
=
self
.
_fake_labels
())
accuracy_tf
=
metrics
.
sequence_accuracy
(
labels_tf
,
labels_tf
,
self
.
rej_char
)
...
...
@@ -66,9 +66,9 @@ class AccuracyTest(tf.test.TestCase):
def
test_sequence_accuracy_one_char_difference
(
self
):
ground_truth_np
=
self
.
_fake_labels
()
ground_truth_tf
=
tf
.
convert_to_tensor
(
ground_truth_np
)
ground_truth_tf
=
tf
.
convert_to_tensor
(
value
=
ground_truth_np
)
prediction_tf
=
tf
.
convert_to_tensor
(
self
.
_incorrect_copy
(
ground_truth_np
,
bad_indexes
=
((
0
,
0
))))
value
=
self
.
_incorrect_copy
(
ground_truth_np
,
bad_indexes
=
((
0
,
0
))))
accuracy_tf
=
metrics
.
sequence_accuracy
(
prediction_tf
,
ground_truth_tf
,
self
.
rej_char
)
...
...
@@ -80,9 +80,9 @@ class AccuracyTest(tf.test.TestCase):
def
test_char_accuracy_one_char_difference_with_padding
(
self
):
ground_truth_np
=
self
.
_fake_labels
()
ground_truth_tf
=
tf
.
convert_to_tensor
(
ground_truth_np
)
ground_truth_tf
=
tf
.
convert_to_tensor
(
value
=
ground_truth_np
)
prediction_tf
=
tf
.
convert_to_tensor
(
self
.
_incorrect_copy
(
ground_truth_np
,
bad_indexes
=
((
0
,
0
))))
value
=
self
.
_incorrect_copy
(
ground_truth_np
,
bad_indexes
=
((
0
,
0
))))
accuracy_tf
=
metrics
.
char_accuracy
(
prediction_tf
,
ground_truth_tf
,
self
.
rej_char
)
...
...
Prev
1
…
3
4
5
6
7
8
9
10
11
…
20
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment