Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
412f4d2e
Commit
412f4d2e
authored
Dec 23, 2020
by
Ruoxin Sang
Committed by
A. Unique TensorFlower
Dec 23, 2020
Browse files
Add host training support for StandardEvaluator.
PiperOrigin-RevId: 348868176
parent
b7930ff9
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
143 additions
and
15 deletions
+143
-15
orbit/standard_runner.py
orbit/standard_runner.py
+44
-11
orbit/standard_runner_test.py
orbit/standard_runner_test.py
+45
-4
orbit/utils/loop_fns.py
orbit/utils/loop_fns.py
+54
-0
No files found.
orbit/standard_runner.py
View file @
412f4d2e
...
...
@@ -50,12 +50,12 @@ class StandardTrainerOptions:
"""Advanced options for `orbit.StandardTrainer`.
Attributes:
use_tf_while_loop: A boolean indicating whether to run the training loop
using a `tf.while_loop`. If `True`, `use_tf_function` must also be `True`.
use_tf_function: A boolean indicating whether to apply `tf.function` to the
training loop. This will only affect the body of the loop (involving
`train_step`); `train_loop_begin` and `train_loop_end` will always be run
in eager mode.
use_tf_while_loop: A boolean indicating whether to run the training loop
using a `tf.while_loop`. If `True`, `use_tf_function` must also be `True`.
use_tpu_summary_optimization: A boolean indicating whether to enable a
performance optimization for summaries in TPUs. Writing summaries
conditionally with outside compilation on TPUs can be extremely slow. If
...
...
@@ -63,8 +63,8 @@ class StandardTrainerOptions:
(one with summary calls, and one without). The program with summaries runs
only for one step when summaries should be recorded.
"""
use_tf_while_loop
:
bool
=
True
use_tf_function
:
bool
=
True
use_tf_while_loop
:
bool
=
True
use_tpu_summary_optimization
:
bool
=
False
...
...
@@ -215,14 +215,30 @@ class StandardEvaluatorOptions:
training loop. This will only affect the body of the loop (involving
`train_step`); `train_loop_begin` and `train_loop_end` will always be run
in eager mode.
use_tf_while_loop: A boolean indicating whether to run the training loop
using a `tf.while_loop`. If `True`, `use_tf_function` must also be `True`.
"""
use_tf_function
:
bool
=
True
use_tf_while_loop
:
bool
=
False
def _create_eval_loop_fn(eval_step_fn, has_state: bool,
                         options: StandardEvaluatorOptions):
  """Builds the evaluation loop callable selected by `options`.

  Args:
    eval_step_fn: The per-step evaluation function.
    has_state: Whether the evaluation carries a `state` accumulated across
      steps (i.e. `eval_begin` returned a non-`None` value).
    options: A `StandardEvaluatorOptions` controlling loop construction.

  Returns:
    A callable implementing the evaluation loop.
  """
  if not options.use_tf_while_loop:
    # Eager-style loop; optionally wrap just the step in `tf.function`.
    if options.use_tf_function:
      eval_step_fn = tf.function(eval_step_fn)
    return loop_fns.create_loop_fn(eval_step_fn)

  # TODO(b/176126742): tf.while_loop doesn't support `None` as a loop input
  # even when it is not used inside the loop. To workaround this limitation,
  # we have to build two tf.functions for it.
  make_loop = (
      loop_fns.create_tf_while_loop_fn_with_state
      if has_state else loop_fns.create_tf_while_loop_fn)
  return tf.function(make_loop(eval_step_fn))
class
StandardEvaluator
(
runner
.
AbstractEvaluator
,
metaclass
=
abc
.
ABCMeta
):
...
...
@@ -254,7 +270,12 @@ class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta):
`DistributedDataset`.
options: An `orbit.StandardEvaluatorOptions` instance.
"""
self
.
_eval_options
=
options
or
StandardEvaluatorOptions
()
options
=
options
or
StandardEvaluatorOptions
()
if
options
.
use_tf_while_loop
and
not
options
.
use_tf_function
:
raise
ValueError
(
"`use_tf_while_loop=True` and `use_tf_function=False` "
"is not supported"
)
self
.
_eval_options
=
options
self
.
_eval_dataset
=
eval_dataset
self
.
_eval_loop_fn
=
None
...
...
@@ -268,14 +289,26 @@ class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta):
Returns:
The output of `self.eval_end()`.
Raises:
ValueError: If `options.use_tf_while_loop` is `True` and `num_steps` is
unspecified.
"""
if
self
.
_eval_options
.
use_tf_while_loop
and
num_steps
==
-
1
:
raise
ValueError
(
"Looping until exhausted is not supported if "
"`options.use_tf_while_loop` is `True`"
)
outputs
=
self
.
eval_begin
()
# pylint: disable=assignment-from-no-return
has_state
=
outputs
is
not
None
if
self
.
_eval_loop_fn
is
None
:
self
.
_eval_loop_fn
=
_create_eval_loop_fn
(
self
.
eval_step
,
options
=
self
.
_eval_options
)
self
.
eval_step
,
has_state
=
has_state
,
options
=
self
.
_eval_options
)
eval_iter
=
tf
.
nest
.
map_structure
(
iter
,
self
.
eval_dataset
)
if
self
.
_eval_options
.
use_tf_while_loop
and
not
has_state
:
self
.
_eval_loop_fn
(
eval_iter
,
num_steps
)
else
:
outputs
=
self
.
_eval_loop_fn
(
eval_iter
,
num_steps
,
state
=
outputs
,
reduce_fn
=
self
.
eval_reduce
)
...
...
orbit/standard_runner_test.py
View file @
412f4d2e
...
...
@@ -14,6 +14,8 @@
# ==============================================================================
"""Tests for orbit.standard_runner."""
from
absl.testing
import
parameterized
from
orbit
import
standard_runner
from
orbit
import
utils
...
...
@@ -79,7 +81,36 @@ class TestEvaluator(standard_runner.StandardEvaluator):
return
self
.
global_step
.
numpy
()
class
StandardRunnerTest
(
tf
.
test
.
TestCase
):
class TestEvaluatorWithOutputsAggregation(standard_runner.StandardEvaluator):
  """A StandardEvaluator subclass for tests."""

  def __init__(self, options=None):
    # Default (possibly single-replica) strategy is sufficient for the test.
    self.strategy = tf.distribute.get_strategy()
    distributed_ds = self.strategy.distribute_datasets_from_function(
        lambda _: tf.data.Dataset.range(10))
    super().__init__(eval_dataset=distributed_ds, options=options)

  def eval_begin(self):
    # Seed the aggregation state with a single zero entry.
    return tf.constant((0.0,))

  def eval_reduce(self, state, step_outputs):
    # Grow the state by appending this step's outputs.
    return tf.concat([state, step_outputs], 0)

  def eval_step(self, iterator):

    def replica_step(x):
      # Sum the batch on each replica as a float.
      return tf.reduce_sum(tf.cast(x, tf.float32))

    per_replica = self.strategy.run(replica_step, args=(next(iterator),))
    return self.strategy.experimental_local_results(per_replica)

  def eval_end(self, outputs):
    # Collapse all accumulated per-step sums into a single scalar.
    return tf.reduce_sum(outputs)
class
StandardRunnerTest
(
parameterized
.
TestCase
):
def
test_default_trainer
(
self
):
trainer
=
TestTrainer
()
...
...
@@ -91,10 +122,20 @@ class StandardRunnerTest(tf.test.TestCase):
trainer
=
TestTrainer
(
options
)
self
.
assertEqual
(
trainer
.
train
(
tf
.
constant
(
10
)),
10
)
def
test_default_evaluator
(
self
):
evaluator
=
TestEvaluator
()
@
parameterized
.
named_parameters
((
"use_tf_while_loop"
,
True
),
(
""
,
False
))
def
test_default_evaluator
(
self
,
use_tf_while_loop
):
options
=
standard_runner
.
StandardEvaluatorOptions
(
use_tf_while_loop
=
use_tf_while_loop
)
evaluator
=
TestEvaluator
(
options
)
self
.
assertEqual
(
evaluator
.
evaluate
(
tf
.
constant
(
10
)),
10
)
@
parameterized
.
named_parameters
((
"use_tf_while_loop"
,
True
),
(
""
,
False
))
def
test_evaluator_with_outputs_aggregation
(
self
,
use_tf_while_loop
):
options
=
standard_runner
.
StandardEvaluatorOptions
(
use_tf_while_loop
=
use_tf_while_loop
)
evaluator
=
TestEvaluatorWithOutputsAggregation
(
options
)
self
.
assertEqual
(
evaluator
.
evaluate
(
tf
.
constant
(
10
)),
45
)
# Standard TensorFlow test entry point.
if __name__ == "__main__":
  tf.test.main()
orbit/utils/loop_fns.py
View file @
412f4d2e
...
...
@@ -117,6 +117,60 @@ def create_tf_while_loop_fn(step_fn):
return
loop_fn
def create_tf_while_loop_fn_with_state(step_fn):
  """Creates a TF while loop function with state.

  This function is similar to `create_tf_while_loop_fn`, but allowing a `state`
  to be accumulated over multiple iterations of the loop. Note that the
  structure of the `state` cannot be changed across iterations.

  Args:
    step_fn: A function taking a nested structure of `tf.data.Iterator` or
      `DistributedIterator`. Its return value is passed to `reduce_fn` to
      update the accumulated `state` on every iteration.

  Returns:
    A loop function taking required `iterator`, `num_steps`, `state` and
    `reduce_fn` parameters. If called inside a `tf.function`, the loop will be
    converted by AutoGraph into a `tf.while_loop` construct. See the `loop_fn`
    definition below for additional details.
  """

  def loop_fn_with_state(iterator, num_steps, state, reduce_fn):
    """Makes `num_steps` calls to `step_fn(iterator)`.

    Args:
      iterator: A nested structure of `tf.data.Iterator` or
        `DistributedIterator`.
      num_steps: The number of steps in the loop. Should be passed as a
        `tf.Tensor`. Iterating until iterator exhaustion is not supported.
      state: An initial state before running the loop.
      reduce_fn: A callable taking two inputs, `state` and `value`, where
        `state` is the previous output from `reduce_fn`, and `value` is the
        output from `step_fn`.

    Returns:
      The final state returned by `reduce_fn`.

    Raises:
      ValueError: If `num_steps` is not a `tf.Tensor`.
    """
    if not isinstance(num_steps, tf.Tensor):
      raise ValueError(
          "`num_steps` should be a `tf.Tensor`. Passing a Python value can "
          "cause unnecessary retracing when wrapped by `tf.function`.")

    for _ in tf.range(num_steps):
      # Relax the shapes within the loop, so the shape of `state` can change
      # across iterations. This is useful to aggregate outputs from each step
      # and concat to `state`.
      tf.autograph.experimental.set_loop_options(
          shape_invariants=[(t, tf.TensorShape([None] * t.shape.rank))
                            for t in tf.nest.flatten(state)
                            if tf.is_tensor(t)])
      outputs = step_fn(iterator)
      state = reduce_fn(state, outputs)
    return state

  return loop_fn_with_state
class
LoopFnWithSummaries
(
tpu_summaries
.
OptionalSummariesFunction
):
"""Implements a two-program approach for optimizing summaries on TPU.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment