ModelZoo / ResNet50_tensorflow / Commits / 71d2680d

Commit 71d2680d, authored May 01, 2020 by Yeqing Li, committed by A. Unique TensorFlower on May 01, 2020.

Internal change

PiperOrigin-RevId: 309486898
Parent: 10b38209

Changes: 1 changed file with 52 additions and 36 deletions.

official/modeling/training/distributed_executor.py (+52, -36)
@@ -134,26 +134,27 @@ class SummaryWriter(object):
 class DistributedExecutor(object):
   """Interface to train and eval models with tf.distribute.Strategy.
+  """
 
-  Arguments:
-    strategy: an instance of tf.distribute.Strategy.
-    params: Model configuration needed to run distribution strategy.
-    model_fn: Keras model function. Signature:
-      (params: ParamsDict) -> tf.keras.models.Model.
-    loss_fn: loss function. Signature:
-      (y_true: Tensor, y_pred: Tensor) -> Tensor
-    metric_fn: metric function. Signature: () -> tf.keras.metrics.Metric.
-    is_multi_host: Set to True when using multi hosts for training, like multi
-      worker GPU or TPU pod (slice). Otherwise, False.
-  """
-
-  def __init__(self, strategy, params, model_fn, loss_fn, is_multi_host=False):
+  def __init__(self, strategy, params, model_fn, loss_fn, is_multi_host=False):
+    """Constructor.
+
+    Args:
+      strategy: an instance of tf.distribute.Strategy.
+      params: Model configuration needed to run distribution strategy.
+      model_fn: Keras model function. Signature:
+        (params: ParamsDict) -> tf.keras.models.Model.
+      loss_fn: loss function. Signature:
+        (y_true: Tensor, y_pred: Tensor) -> Tensor
+      is_multi_host: Set to True when using multi hosts for training, like multi
+        worker GPU or TPU pod (slice). Otherwise, False.
+    """
     self._params = params
     self._model_fn = model_fn
     self._loss_fn = loss_fn
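
The net effect of this hunk is that the constructor arguments are now documented on __init__ itself, and the stale metric_fn entry is dropped (since __init__ takes no such argument). Below is a minimal construction sketch under these signatures, assuming the Model Garden package layout this file ships in; the my_* callables and the ParamsDict contents are illustrative placeholders, not taken from this commit.

import tensorflow as tf

from official.modeling.hyperparams import params_dict
from official.modeling.training.distributed_executor import DistributedExecutor

def my_model_fn(params):
  # (params: ParamsDict) -> tf.keras.models.Model
  return tf.keras.Sequential([tf.keras.layers.Dense(1)])

def my_loss_fn(y_true, y_pred):
  # (y_true: Tensor, y_pred: Tensor) -> Tensor
  return tf.reduce_mean(tf.square(y_true - y_pred))

executor = DistributedExecutor(
    strategy=tf.distribute.MirroredStrategy(),
    params=params_dict.ParamsDict({'train': {'batch_size': 8}}),  # illustrative
    model_fn=my_model_fn,
    loss_fn=my_loss_fn,
    is_multi_host=False)
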
@@ -224,6 +225,18 @@ class DistributedExecutor(object):
                               loss_fn,
                               optimizer,
                               metric=None):
+    """Creates a single training step.
+
+    Args:
+      strategy: an instance of tf.distribute.Strategy.
+      model: (Tensor, bool) -> Tensor. model function.
+      loss_fn: (y_true: Tensor, y_pred: Tensor) -> Tensor.
+      optimizer: tf.keras.optimizers.Optimizer.
+      metric: tf.keras.metrics.Metric subclass.
+
+    Returns:
+      The training step callable.
+    """
     metrics = metrics_as_dict(metric)
 
     def _replicated_step(inputs):
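
For context, a _replicated_step in this style of executor typically runs the forward and backward pass on one replica and updates the metric locally. The following is a generic, self-contained sketch of that pattern, not the body of the actual method:

import tensorflow as tf

# Placeholder model, optimizer, and metric; the real executor builds these
# from model_fn and params.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.SGD(0.1)
metric = tf.keras.metrics.Mean('training_loss')

def loss_fn(y_true, y_pred):
  return tf.reduce_mean(tf.square(y_true - y_pred))

def _replicated_step(inputs):
  features, labels = inputs
  with tf.GradientTape() as tape:
    outputs = model(features, training=True)
    loss = loss_fn(labels, outputs)
  grads = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(grads, model.trainable_variables))
  metric.update_state(loss)  # each replica updates its local copy
  return loss

print(_replicated_step((tf.random.normal([8, 4]), tf.random.normal([8, 1]))))
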
@@ -257,13 +270,12 @@ class DistributedExecutor(object):
       model: (Tensor, bool) -> Tensor. model function.
       loss_fn: (y_true: Tensor, y_pred: Tensor) -> Tensor.
       optimizer: tf.keras.optimizers.Optimizer.
-      iterator: an iterator that yields input tensors.
       metric: tf.keras.metrics.Metric subclass.
 
     Returns:
       The training step callable.
     """
-    _replicated_step = self._create_replicated_step(strategy, model, loss_fn,
-                                                    optimizer, metric)
+    replicated_step = self._create_replicated_step(strategy, model, loss_fn,
+                                                   optimizer, metric)
 
     @tf.function
@@ -282,10 +294,10 @@ class DistributedExecutor(object):
           'retracing.')
 
       per_replica_losses = strategy.run(
-          _replicated_step, args=(next(iterator),))
+          replicated_step, args=(next(iterator),))
       for _ in tf.range(num_steps - 1):
         per_replica_losses = strategy.run(
-            _replicated_step, args=(next(iterator),))
+            replicated_step, args=(next(iterator),))
 
       # For reporting, we returns the mean of losses.
       losses = tf.nest.map_structure(
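
This hunk just tracks the rename from _replicated_step to replicated_step; the surrounding code is the standard tf.function multi-step loop, where num_steps is passed as a tensor to avoid retracing (the 'retracing.' context line is the tail of that warning message). A runnable sketch of the same pattern on a single device, with the same first-step-outside-the-loop structure; the toy dataset and step function are placeholders, not code from this file:

import tensorflow as tf

strategy = tf.distribute.OneDeviceStrategy('/cpu:0')
dataset = tf.data.Dataset.from_tensor_slices(tf.range(64.)).batch(8)
iterator = iter(strategy.experimental_distribute_dataset(dataset))

def replicated_step(batch):
  return tf.reduce_mean(batch)  # stands in for the real forward/backward pass

@tf.function
def train_steps(iterator, num_steps):
  # First step outside the tf.range loop, as in the hunk above.
  per_replica_losses = strategy.run(replicated_step, args=(next(iterator),))
  for _ in tf.range(num_steps - 1):
    per_replica_losses = strategy.run(replicated_step, args=(next(iterator),))
  # For reporting, return the mean of the last step's per-replica losses.
  return strategy.reduce(
      tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)

print(train_steps(iterator, tf.constant(4)))  # tensor arg avoids retracing
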
@@ -318,7 +330,6 @@ class DistributedExecutor(object):
 
     return test_step
 
-
   def train(self,
             train_input_fn: Callable[[params_dict.ParamsDict], tf.data.Dataset],
             eval_input_fn: Callable[[params_dict.ParamsDict],
@@ -404,6 +415,7 @@ class DistributedExecutor(object):
     train_iterator = self._get_input_iterator(train_input_fn, strategy)
     train_loss = None
     eval_metric_result = None
+    tf.keras.backend.set_learning_phase(1)
     with strategy.scope():
       # To correctly place the model weights on accelerators,
       # model and optimizer should be created in scope.
@@ -584,10 +596,10 @@ class DistributedExecutor(object):
     """Runs distributed evaluation on model folder.
 
     Args:
+      model_dir: the folder for storing model checkpoints.
       eval_input_fn: (Optional) same type as train_input_fn. If not None, will
         trigger evaluting metric on eval data. If None, will not run eval step.
       eval_metric_fn: metric_fn for evaluation in test_step.
-      model_dir: the folder for storing model checkpoints.
       total_steps: total training steps. If the current step reaches the
         total_steps, the evaluation loop will stop.
      eval_timeout: The maximum number of seconds to wait between checkpoints.
@@ -638,11 +650,11 @@ class DistributedExecutor(object):
     """Runs distributed evaluation on the one checkpoint.
 
    Args:
+      checkpoint_path: the checkpoint to evaluate.
       eval_input_fn: (Optional) same type as train_input_fn. If not None, will
         trigger evaluting metric on eval data. If None, will not run eval step.
       eval_metric_fn: metric_fn for evaluation in test_step.
-      checkpoint_path: the checkpoint to evaluate.
-      summary_writer: function to create summary writer.
+      summary_writer_fn: function to create summary writer.
 
     Returns:
       Eval metrics dictionary of the last checkpoint.
@@ -651,6 +663,8 @@ class DistributedExecutor(object):
       raise ValueError('if `eval_metric_fn` is specified, '
                        'eval_metric_fn must be a callable.')
+    old_phrase = tf.keras.backend.learning_phase()
+    tf.keras.backend.set_learning_phase(0)
     params = self._params
     strategy = self._strategy
 
     # To reduce unnecessary send/receive input pipeline operation, we place
@@ -686,6 +700,7 @@ class DistributedExecutor(object):
       summary_writer(metrics=eval_metric_result, step=current_step)
       reset_states(eval_metric)
 
+    tf.keras.backend.set_learning_phase(old_phrase)
     return eval_metric_result, current_step
 
   def predict(self):
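
Together with the previous hunk, this saves the global Keras learning phase, pins it to 0 (inference) for the duration of evaluation, and restores the old value afterwards; the source spells the saved variable old_phrase. A minimal sketch of that save/restore pattern, here wrapped in try/finally for illustration; run_in_inference_mode is a hypothetical helper, and set_learning_phase was deprecated in later TF releases:

import tensorflow as tf

def run_in_inference_mode(fn):
  # Hypothetical helper showing the pattern added by the two hunks above.
  old_phase = tf.keras.backend.learning_phase()  # the diff names this old_phrase
  tf.keras.backend.set_learning_phase(0)  # 0 = inference, 1 = training
  try:
    return fn()
  finally:
    tf.keras.backend.set_learning_phase(old_phase)

# Inside the wrapped call the phase reads as 0; afterwards it is restored.
print(run_in_inference_mode(tf.keras.backend.learning_phase))
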
@@ -726,18 +741,20 @@ class ExecutorBuilder(object):
         model_fn=my_model_fn,
         loss_fn=my_loss_fn,
         metric_fn=my_metric_fn)
-
-  Args:
-    strategy_type: string. One of 'tpu', 'mirrored', 'multi_worker_mirrored'. If
-      None. User is responsible to set the strategy before calling
-      build_executor(...).
-    strategy_config: necessary config for constructing the proper Strategy.
-      Check strategy_flags_dict() for examples of the structure.
   """
 
   def __init__(self, strategy_type=None, strategy_config=None):
+    """Constructor.
+
+    Args:
+      strategy_type: string. One of 'tpu', 'mirrored', 'multi_worker_mirrored'.
+        If None. User is responsible to set the strategy before calling
+        build_executor(...).
+      strategy_config: necessary config for constructing the proper Strategy.
+        Check strategy_flags_dict() for examples of the structure.
+    """
     _ = distribution_utils.configure_cluster(strategy_config.worker_hosts,
                                              strategy_config.task_index)
     self._strategy = distribution_utils.get_distribution_strategy(
         distribution_strategy=strategy_type,
         num_gpus=strategy_config.num_gpus,
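
A hypothetical construction mirroring the relocated docstring; my_strategy_flags stands in for a real config object exposing num_gpus, worker_hosts, and task_index, and the my_* callables are placeholders as in the earlier sketch:

builder = ExecutorBuilder(
    strategy_type='mirrored',           # or 'tpu', 'multi_worker_mirrored'
    strategy_config=my_strategy_flags)  # hypothetical flags/config object
executor = builder.build_executor(
    class_ctor=DistributedExecutor,
    params=my_params,                   # hypothetical ParamsDict
    model_fn=my_model_fn,
    loss_fn=my_loss_fn,
    metric_fn=my_metric_fn)
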
@@ -755,7 +772,6 @@ class ExecutorBuilder(object):
     """Sets default summary writer for the current thread."""
     self._strategy = new_strategy
 
-
   def build_executor(self,
                      class_ctor=DistributedExecutor,
                      params=None,