ModelZoo / ResNet50_tensorflow · Commit 2e77bb3e
Authored Mar 13, 2021 by Yeqing Li; committed by A. Unique TensorFlower, Mar 13, 2021
Parent: f7ea371e

Adds tf.distribute.experimental.ParameterServerStrategy support to Orbit.

PiperOrigin-RevId: 362729857

Showing 3 changed files with 294 additions and 45 deletions:
  official/core/base_trainer.py        +85   -7
  official/core/base_trainer_test.py   +163  -0
  orbit/standard_runner.py             +46   -38
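For context, TF2 parameter server training pairs tf.distribute.experimental.ParameterServerStrategy with a ClusterCoordinator that schedules step functions onto remote workers and joins them later; the hooks added in this commit wrap exactly that schedule/join flow. Below is a minimal sketch of the underlying pattern, assuming a cluster is already reachable via TF_CONFIG; the resolver, variable, and step function are illustrative and not taken from this commit.

import tensorflow as tf

# Illustrative setup: in a real job, TF_CONFIG points at running worker/ps tasks.
cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
strategy = tf.distribute.experimental.ParameterServerStrategy(cluster_resolver)
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(strategy)

with strategy.scope():
  step = tf.Variable(0, dtype=tf.int64)  # placed on a parameter server


def dataset_fn(input_context=None):
  del input_context
  return tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0]).repeat().batch(2)


@tf.function
def train_step(iterator):
  _ = next(iterator)   # consume one per-worker batch
  step.assign_add(1)   # update the parameter-server-hosted variable


per_worker_dataset = coordinator.create_per_worker_dataset(tf.function(dataset_fn))
per_worker_iterator = iter(per_worker_dataset)

for _ in range(10):
  coordinator.schedule(train_step, args=(per_worker_iterator,))  # non-blocking
coordinator.join()  # block until all scheduled steps finish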
official/core/base_trainer.py

@@ -18,7 +18,7 @@ The base trainer implements the Orbit `StandardTrainable` and
 `StandardEvaluable` interfaces. Trainers inside this project should be
 interchangeable and independent of model architectures and tasks.
 """
+import functools
 from absl import logging
 import gin
 import orbit
@@ -84,10 +84,85 @@ class Recovery:
         "%f at step %d.", checkpoint_path, loss_value, global_step)
 
 
+class _AsyncTrainer(orbit.StandardTrainer, orbit.StandardEvaluator):
+  """Trainer class for both sync and async Strategy."""
+
+  def init_async(self):
+    """Initializes the Async Trainer base class."""
+    assert isinstance(self._strategy, tf.distribute.Strategy)
+    self._is_async = isinstance(
+        self._strategy, tf.distribute.experimental.ParameterServerStrategy)
+    self._coordinator = None
+    if self._is_async:
+      self._coordinator = (
+          tf.distribute.experimental.coordinator.ClusterCoordinator(
+              self._strategy))
+
+  def join(self):
+    """Joins all async steps. Only useful in async training."""
+    if getattr(self, "_is_async", False):
+      self._coordinator.join()
+
+  def create_train_loop_fn(self):
+    """Creates a train loop from the given step function and options."""
+    train_loop_fn = super().create_train_loop_fn()
+    if getattr(self, "_is_async", False):
+
+      def _async_loop_fn(iterator, num_steps):
+        self._coordinator.schedule(train_loop_fn, args=(iterator, num_steps))
+
+      return _async_loop_fn
+    else:
+      return train_loop_fn
+
+  def create_eval_loop_fn(self, has_state: bool):
+    """Creates an eval loop from the given step function and options."""
+    eval_loop_fn = super().create_eval_loop_fn(has_state)
+    if getattr(self, "_is_async", False):
+      if has_state:
+        raise ValueError(
+            "Stateful eval loop is not supported in async training.")
+
+      def _async_loop_fn(iterator, num_steps, state=None, reduce_fn=None):
+        assert state is None
+        assert reduce_fn is None
+        self._coordinator.schedule(eval_loop_fn, args=(iterator, num_steps))
+
+      return _async_loop_fn
+    else:
+      return eval_loop_fn
+
+  def distribute_dataset(self, dataset_or_fn, *args, **kwargs):
+    """A utility function to help create a `tf.distribute.DistributedDataset`.
+
+    Args:
+      dataset_or_fn: An instance of `tf.data.Dataset`, or a "dataset function"
+        returning a `tf.data.Dataset`. If it is a function, it may optionally
+        have an argument named `input_context` which will be passed a
+        `tf.distribute.InputContext` instance.
+      *args: Any positional arguments to pass through to `dataset_or_fn`.
+      **kwargs: Any keyword arguments to pass through to `dataset_or_fn`.
+
+    Returns:
+      A distributed Dataset.
+    """
+    if getattr(self, "_is_async", False):
+      per_worker_dataset_fn = functools.partial(
+          orbit.utils.make_distributed_dataset, self._strategy, dataset_or_fn,
+          *args, **kwargs)
+      per_worker_dataset_fn = tf.function(per_worker_dataset_fn)
+      return self._coordinator.create_per_worker_dataset(per_worker_dataset_fn)
+    else:
+      return orbit.utils.make_distributed_dataset(self._strategy,
+                                                  dataset_or_fn, *args,
+                                                  **kwargs)
+
+
 @gin.configurable
-class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
+class Trainer(_AsyncTrainer):
   """Implements the common trainer shared for TensorFlow models."""
 
+  # pylint: disable=super-init-not-called
   def __init__(self,
                config: ExperimentConfig,
                task: base_task.Task,
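As the `distribute_dataset` docstring above notes, the argument can be a "dataset function" that optionally accepts an `input_context` used for per-worker sharding. A minimal sketch of such a function follows; the feature shapes and the name `build_inputs` are illustrative, not part of this commit.

import tensorflow as tf

def build_inputs(input_context: tf.distribute.InputContext = None):
  # Illustrative dataset function usable with distribute_dataset().
  dataset = tf.data.Dataset.from_tensor_slices(tf.zeros((8, 4)))
  if input_context is not None:
    # Shard across input pipelines when a context is provided.
    dataset = dataset.shard(input_context.num_input_pipelines,
                            input_context.input_pipeline_id)
  return dataset.repeat().batch(2)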
@@ -147,9 +222,11 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
     self._validation_metrics = self.task.build_metrics(
         training=False) + self.model.metrics
 
+    self.init_async()
+
     if train:
-      train_dataset = orbit.utils.make_distributed_dataset(
-          self.strategy, self.task.build_inputs, self.config.task.train_data)
+      train_dataset = self.distribute_dataset(
+          self.task.build_inputs, self.config.task.train_data)
       orbit.StandardTrainer.__init__(
           self,
           train_dataset,
@@ -159,9 +236,8 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
               use_tpu_summary_optimization=config.trainer.allow_tpu_summary))
 
     if evaluate:
-      eval_dataset = orbit.utils.make_distributed_dataset(
-          self.strategy, self.task.build_inputs,
-          self.config.task.validation_data)
+      eval_dataset = self.distribute_dataset(
+          self.task.build_inputs, self.config.task.validation_data)
       orbit.StandardEvaluator.__init__(
           self,
           eval_dataset,
@@ -270,6 +346,7 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
   def train_loop_end(self):
     """See base class."""
+    self.join()
     # Checks if the model numeric status is stable and conducts the checkpoint
     # recovery accordingly.
     if self._recovery:
@@ -324,6 +401,7 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
   def eval_end(self, aggregated_logs=None):
     """Processes evaluation results."""
+    self.join()
     logs = {}
     for metric in self.validation_metrics:
       logs[metric.name] = metric.result()
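The `join()` calls added to `train_loop_end` and `eval_end` matter because, under `ParameterServerStrategy`, the wrapped loop functions only schedule work on the coordinator and return immediately; metrics are safe to read only once the scheduled steps have finished. A hedged sketch of that ordering follows; `coordinator`, `eval_loop_fn`, `iterator`, and `metric` stand in for the objects the trainer builds and this is not code from the commit.

import tensorflow as tf

def finish_evaluation(coordinator, eval_loop_fn, iterator, metric):
  # Schedule the eval loop asynchronously on remote workers.
  coordinator.schedule(eval_loop_fn, args=(iterator, tf.constant(100)))
  # Without join(), metric.result() could be read while steps are still
  # running remotely; join() blocks until everything scheduled has finished.
  coordinator.join()
  return metric.result().numpy()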
official/core/base_trainer_test.py

@@ -14,9 +14,13 @@
 """Tests for tensorflow_models.core.trainers.trainer."""
 # pylint: disable=g-direct-tensorflow-import
+import multiprocessing
 import os
+import sys
 
 from absl.testing import parameterized
 import numpy as np
+import portpicker
 import tensorflow as tf
 
 from tensorflow.python.distribute import combinations
@@ -26,6 +30,9 @@ from official.core import config_definitions as cfg
 from official.core import train_lib
 from official.utils.testing import mock_task
 
+TPU_TEST = 'test_tpu' in sys.argv[0]
+GPU_TEST = 'test_gpu' in sys.argv[0]
+
 
 def all_strategy_combinations():
   return combinations.combine(
@@ -36,6 +43,113 @@ def all_strategy_combinations():
       ],)
 
 
+def create_in_process_cluster(num_workers, num_ps):
+  """Creates and starts local servers and returns the cluster_resolver."""
+  worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
+  ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]
+
+  cluster_dict = {}
+  cluster_dict['worker'] = ['localhost:%s' % port for port in worker_ports]
+  if num_ps > 0:
+    cluster_dict['ps'] = ['localhost:%s' % port for port in ps_ports]
+
+  cluster_spec = tf.train.ClusterSpec(cluster_dict)
+
+  # Workers need some inter_ops threads to work properly.
+  worker_config = tf.compat.v1.ConfigProto()
+  if multiprocessing.cpu_count() < num_workers + 1:
+    worker_config.inter_op_parallelism_threads = num_workers + 1
+
+  for i in range(num_workers):
+    tf.distribute.Server(
+        cluster_spec,
+        job_name='worker',
+        task_index=i,
+        config=worker_config,
+        protocol='grpc')
+
+  for i in range(num_ps):
+    tf.distribute.Server(
+        cluster_spec, job_name='ps', task_index=i, protocol='grpc')
+
+  cluster_resolver = tf.distribute.cluster_resolver.SimpleClusterResolver(
+      cluster_spec, rpc_layer='grpc')
+  return cluster_resolver
+
+
+def dataset_fn(input_context=None):
+  del input_context
+
+  def dummy_data(_):
+    return tf.zeros((1, 1), dtype=tf.float32)
+
+  dataset = tf.data.Dataset.range(1)
+  dataset = dataset.repeat()
+  dataset = dataset.map(
+      dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+  return dataset
+
+
+class MockAsyncTrainer(trainer_lib._AsyncTrainer):
+  """Mock AsyncTrainer to test the _AsyncTrainer class."""
+
+  def __init__(self):
+    self._strategy = tf.distribute.get_strategy()
+    self.init_async()
+
+    self.global_step = tf.Variable(
+        0,
+        dtype=tf.int64,
+        name='global_step',
+        trainable=False,
+        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
+    self.eval_global_step = tf.Variable(
+        0,
+        dtype=tf.int64,
+        name='eval_global_step',
+        trainable=False,
+        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
+
+    train_dataset = self.distribute_dataset(dataset_fn)
+    trainer_lib.orbit.StandardTrainer.__init__(
+        self, train_dataset,
+        options=trainer_lib.orbit.StandardTrainerOptions())
+
+    eval_dataset = self.distribute_dataset(dataset_fn)
+    trainer_lib.orbit.StandardEvaluator.__init__(
+        self,
+        eval_dataset,
+        options=trainer_lib.orbit.StandardEvaluatorOptions(
+            use_tf_while_loop=True))
+
+  def train_loop_begin(self):
+    self.global_step.assign(0)
+
+  def train_step(self, iterator):
+
+    def replica_step(_):
+      self.global_step.assign_add(1)
+
+    self._strategy.run(replica_step, args=(next(iterator),))
+
+  def train_loop_end(self):
+    self.join()
+    return self.global_step.numpy()
+
+  def eval_begin(self):
+    self.eval_global_step.assign(0)
+
+  def eval_step(self, iterator):
+
+    def replica_step(_):
+      self.eval_global_step.assign_add(1)
+
+    self._strategy.run(replica_step, args=(next(iterator),))
+
+  def eval_end(self):
+    self.join()
+    return self.eval_global_step.numpy()
+
+
 class TrainerTest(tf.test.TestCase, parameterized.TestCase):
 
   def setUp(self):
@@ -71,6 +185,55 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
     self.assertIn('training_loss', logs)
     self.assertIn('learning_rate', logs)
 
+  def test_base_async_trainer(self):
+    if TPU_TEST or GPU_TEST:
+      self.skipTest('Async training is not available on TPU/GPU.')
+    num_workers = 3
+    num_ps = 2
+    cluster_resolver = create_in_process_cluster(num_workers, num_ps)
+    distribution = tf.distribute.experimental.ParameterServerStrategy(
+        cluster_resolver)
+    with distribution.scope():
+      trainer = MockAsyncTrainer()
+      trainer.init_async()
+      self.assertIsInstance(
+          trainer._coordinator,
+          tf.distribute.experimental.coordinator.ClusterCoordinator)
+      self.assertEqual(trainer.train(tf.constant(10)), 10)
+      self.assertEqual(trainer.evaluate(tf.constant(11)), 11)
+
+  def test_async_trainer_train(self):
+    if TPU_TEST or GPU_TEST:
+      self.skipTest('Async training is not available on TPU/GPU.')
+    num_workers = 3
+    num_ps = 2
+    cluster_resolver = create_in_process_cluster(num_workers, num_ps)
+    distribution = tf.distribute.experimental.ParameterServerStrategy(
+        cluster_resolver)
+    with distribution.scope():
+      config = cfg.ExperimentConfig(**self._config.as_dict())
+      config.trainer.eval_tf_while_loop = True
+      trainer = self.create_test_trainer(config)
+      logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
+      self.assertIn('training_loss', logs)
+      self.assertIn('learning_rate', logs)
+
+  def test_async_trainer_validate(self):
+    if TPU_TEST or GPU_TEST:
+      self.skipTest('Async training is not available on TPU/GPU.')
+    num_workers = 3
+    num_ps = 2
+    cluster_resolver = create_in_process_cluster(num_workers, num_ps)
+    distribution = tf.distribute.experimental.ParameterServerStrategy(
+        cluster_resolver)
+    with distribution.scope():
+      config = cfg.ExperimentConfig(**self._config.as_dict())
+      config.trainer.eval_tf_while_loop = True
+      trainer = self.create_test_trainer(config)
+      logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
+      self.assertIn('acc', logs)
+      self.assertIn('validation_loss', logs)
+
   @combinations.generate(all_strategy_combinations())
   def test_trainer_validate(self, distribution):
     with distribution.scope():
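The helpers above (`create_in_process_cluster` and `dataset_fn`) are also handy outside the tests for spinning up a local parameter-server setup to experiment with. An illustrative reuse, assuming both helpers are importable from base_trainer_test.py; this snippet is not part of the commit.

import tensorflow as tf

cluster_resolver = create_in_process_cluster(num_workers=2, num_ps=1)
strategy = tf.distribute.experimental.ParameterServerStrategy(cluster_resolver)
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(strategy)

# Wrapping the dataset function in tf.function mirrors what
# _AsyncTrainer.distribute_dataset does on the async path.
per_worker_dataset = coordinator.create_per_worker_dataset(tf.function(dataset_fn))
per_worker_iterator = iter(per_worker_dataset)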
orbit/standard_runner.py

@@ -68,21 +68,6 @@ class StandardTrainerOptions:
   use_tpu_summary_optimization: bool = False
 
 
-def _create_train_loop_fn(train_step_fn, options: StandardTrainerOptions):
-  """Creates a training loop from the given step function and options."""
-  if options.use_tf_while_loop:
-    loop_fn = loop_fns.create_tf_while_loop_fn(train_step_fn)
-    if options.use_tpu_summary_optimization:
-      loop_fn = loop_fns.LoopFnWithSummaries(loop_fn)
-    else:
-      loop_fn = tf.function(loop_fn)
-  else:
-    if options.use_tf_function:
-      train_step_fn = tf.function(train_step_fn)
-    loop_fn = loop_fns.create_loop_fn(train_step_fn)
-  return loop_fn
-
-
 class StandardTrainer(runner.AbstractTrainer, metaclass=abc.ABCMeta):
   """Implements standard functionality on top of the AbstractTrainer API.
@@ -119,6 +104,25 @@ class StandardTrainer(runner.AbstractTrainer, metaclass=abc.ABCMeta):
     self._train_iter = None
     self._train_loop_fn = None
 
+  def create_train_loop_fn(self):
+    """Creates a training loop from the current step function and options.
+
+    Returns:
+      The train loop function, i.e. wrapper of multiple train steps.
+    """
+    train_step_fn = self.train_step
+    if self._train_options.use_tf_while_loop:
+      loop_fn = loop_fns.create_tf_while_loop_fn(train_step_fn)
+      if self._train_options.use_tpu_summary_optimization:
+        loop_fn = loop_fns.LoopFnWithSummaries(loop_fn)
+      else:
+        loop_fn = tf.function(loop_fn)
+    else:
+      if self._train_options.use_tf_function:
+        train_step_fn = tf.function(train_step_fn)
+      loop_fn = loop_fns.create_loop_fn(train_step_fn)
+    return loop_fn
+
   def train(self, num_steps: tf.Tensor) -> Optional[runner.Output]:
     """Implements `num_steps` steps of training.
@@ -132,8 +136,7 @@ class StandardTrainer(runner.AbstractTrainer, metaclass=abc.ABCMeta):
     self.train_loop_begin()
     if self._train_loop_fn is None:
-      self._train_loop_fn = _create_train_loop_fn(
-          self.train_step, options=self._train_options)
+      self._train_loop_fn = self.create_train_loop_fn()
     if self._train_iter is None:
       self._train_iter = tf.nest.map_structure(iter, self.train_dataset)
@@ -222,25 +225,6 @@ class StandardEvaluatorOptions:
   use_tf_while_loop: bool = False
 
 
-def _create_eval_loop_fn(eval_step_fn, has_state: bool,
-                         options: StandardEvaluatorOptions):
-  """Create evaluation loop function."""
-  if options.use_tf_while_loop:
-    # TODO(b/176126742): tf.while_loop doesn't support `None` as a loop input
-    # even when it is not used inside the loop. To workaround this limitation,
-    # we have to build two tf.functions for it.
-    if has_state:
-      loop_fn = loop_fns.create_tf_while_loop_fn_with_state(eval_step_fn)
-    else:
-      loop_fn = loop_fns.create_tf_while_loop_fn(eval_step_fn)
-    loop_fn = tf.function(loop_fn)
-  else:
-    if options.use_tf_function:
-      eval_step_fn = tf.function(eval_step_fn)
-    loop_fn = loop_fns.create_loop_fn(eval_step_fn)
-  return loop_fn
-
-
 class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta):
   """Implements the standard functionality of AbstractEvaluator APIs.
@@ -279,6 +263,31 @@ class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta):
     self._eval_dataset = eval_dataset
     self._eval_loop_fn = None
 
+  def create_eval_loop_fn(self, has_state: bool):
+    """Creates an eval loop from the current step function and options.
+
+    Args:
+      has_state: If the step function has state, state will be kept in the
+        loop.
+
+    Returns:
+      The eval loop function, i.e. wrapper of multiple eval steps.
+    """
+    eval_step_fn = self.eval_step
+    if self._eval_options.use_tf_while_loop:
+      # TODO(b/176126742): tf.while_loop doesn't support `None` as a loop input
+      # even when it is not used inside the loop. To workaround this
+      # limitation, we have to build two tf.functions for it.
+      if has_state:
+        loop_fn = loop_fns.create_tf_while_loop_fn_with_state(eval_step_fn)
+      else:
+        loop_fn = loop_fns.create_tf_while_loop_fn(eval_step_fn)
+      loop_fn = tf.function(loop_fn)
+    else:
+      if self._eval_options.use_tf_function:
+        eval_step_fn = tf.function(eval_step_fn)
+      loop_fn = loop_fns.create_loop_fn(eval_step_fn)
+    return loop_fn
+
   def evaluate(self, num_steps: tf.Tensor) -> Optional[runner.Output]:
     """Implements `num_steps` steps of evaluation.
@@ -302,8 +311,7 @@ class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta):
...
@@ -302,8 +311,7 @@ class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta):
has_state
=
outputs
is
not
None
has_state
=
outputs
is
not
None
if
self
.
_eval_loop_fn
is
None
:
if
self
.
_eval_loop_fn
is
None
:
self
.
_eval_loop_fn
=
_create_eval_loop_fn
(
self
.
_eval_loop_fn
=
self
.
create_eval_loop_fn
(
has_state
)
self
.
eval_step
,
has_state
=
has_state
,
options
=
self
.
_eval_options
)
eval_iter
=
tf
.
nest
.
map_structure
(
iter
,
self
.
eval_dataset
)
eval_iter
=
tf
.
nest
.
map_structure
(
iter
,
self
.
eval_dataset
)
if
self
.
_eval_options
.
use_tf_while_loop
and
not
has_state
:
if
self
.
_eval_options
.
use_tf_while_loop
and
not
has_state
:
...
...
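The net effect of this refactor is that loop construction moves from module-level helpers into overridable create_train_loop_fn / create_eval_loop_fn methods, which is exactly the hook _AsyncTrainer uses to wrap loops in coordinator scheduling. A hedged sketch of the extension point follows; the timing wrapper and the TimedTrainer name are illustrative and not part of the commit.

import time

import orbit


class TimedTrainer(orbit.StandardTrainer):
  """Illustrative subclass that wraps the default train loop with timing."""

  def train_step(self, iterator):
    pass  # A real trainer would run its model here.

  def create_train_loop_fn(self):
    loop_fn = super().create_train_loop_fn()

    def timed_loop_fn(iterator, num_steps):
      start = time.time()
      result = loop_fn(iterator, num_steps)
      print('train loop took %.3fs' % (time.time() - start))
      return result

    return timed_loop_fn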