Commit 71d2680d authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 309486898
parent 10b38209
......@@ -134,26 +134,27 @@ class SummaryWriter(object):
class DistributedExecutor(object):
"""Interface to train and eval models with tf.distribute.Strategy.
"""
Arguments:
def __init__(self,
strategy,
params,
model_fn,
loss_fn,
is_multi_host=False):
"""Constructor.
Args:
strategy: an instance of tf.distribute.Strategy.
params: Model configuration needed to run distribution strategy.
model_fn: Keras model function. Signature:
(params: ParamsDict) -> tf.keras.models.Model.
loss_fn: loss function. Signature:
(y_true: Tensor, y_pred: Tensor) -> Tensor
metric_fn: metric function. Signature: () -> tf.keras.metrics.Metric.
is_multi_host: Set to True when using multi hosts for training, like multi
worker GPU or TPU pod (slice). Otherwise, False.
"""
def __init__(self,
strategy,
params,
model_fn,
loss_fn,
is_multi_host=False):
self._params = params
self._model_fn = model_fn
self._loss_fn = loss_fn
......@@ -224,6 +225,18 @@ class DistributedExecutor(object):
loss_fn,
optimizer,
metric=None):
"""Creates a single training step.
Args:
strategy: an instance of tf.distribute.Strategy.
model: (Tensor, bool) -> Tensor. model function.
loss_fn: (y_true: Tensor, y_pred: Tensor) -> Tensor.
optimizer: tf.keras.optimizers.Optimizer.
metric: tf.keras.metrics.Metric subclass.
Returns:
The training step callable.
"""
metrics = metrics_as_dict(metric)
def _replicated_step(inputs):
......@@ -257,13 +270,12 @@ class DistributedExecutor(object):
model: (Tensor, bool) -> Tensor. model function.
loss_fn: (y_true: Tensor, y_pred: Tensor) -> Tensor.
optimizer: tf.keras.optimizers.Optimizer.
iterator: an iterator that yields input tensors.
metric: tf.keras.metrics.Metric subclass.
Returns:
The training step callable.
"""
_replicated_step = self._create_replicated_step(strategy, model, loss_fn,
replicated_step = self._create_replicated_step(strategy, model, loss_fn,
optimizer, metric)
@tf.function
......@@ -282,10 +294,10 @@ class DistributedExecutor(object):
'retracing.')
per_replica_losses = strategy.run(
_replicated_step, args=(next(iterator),))
replicated_step, args=(next(iterator),))
for _ in tf.range(num_steps - 1):
per_replica_losses = strategy.run(
_replicated_step, args=(next(iterator),))
replicated_step, args=(next(iterator),))
# For reporting, we returns the mean of losses.
losses = tf.nest.map_structure(
......@@ -318,7 +330,6 @@ class DistributedExecutor(object):
return test_step
def train(self,
train_input_fn: Callable[[params_dict.ParamsDict], tf.data.Dataset],
eval_input_fn: Callable[[params_dict.ParamsDict],
......@@ -404,6 +415,7 @@ class DistributedExecutor(object):
train_iterator = self._get_input_iterator(train_input_fn, strategy)
train_loss = None
eval_metric_result = None
tf.keras.backend.set_learning_phase(1)
with strategy.scope():
# To correctly place the model weights on accelerators,
# model and optimizer should be created in scope.
......@@ -584,10 +596,10 @@ class DistributedExecutor(object):
"""Runs distributed evaluation on model folder.
Args:
model_dir: the folder for storing model checkpoints.
eval_input_fn: (Optional) same type as train_input_fn. If not None, will
trigger evaluting metric on eval data. If None, will not run eval step.
eval_metric_fn: metric_fn for evaluation in test_step.
model_dir: the folder for storing model checkpoints.
total_steps: total training steps. If the current step reaches the
total_steps, the evaluation loop will stop.
eval_timeout: The maximum number of seconds to wait between checkpoints.
......@@ -638,11 +650,11 @@ class DistributedExecutor(object):
"""Runs distributed evaluation on the one checkpoint.
Args:
checkpoint_path: the checkpoint to evaluate.
eval_input_fn: (Optional) same type as train_input_fn. If not None, will
trigger evaluting metric on eval data. If None, will not run eval step.
eval_metric_fn: metric_fn for evaluation in test_step.
checkpoint_path: the checkpoint to evaluate.
summary_writer_fn: function to create summary writer.
summary_writer: function to create summary writer.
Returns:
Eval metrics dictionary of the last checkpoint.
......@@ -651,6 +663,8 @@ class DistributedExecutor(object):
raise ValueError('if `eval_metric_fn` is specified, '
'eval_metric_fn must be a callable.')
old_phrase = tf.keras.backend.learning_phase()
tf.keras.backend.set_learning_phase(0)
params = self._params
strategy = self._strategy
# To reduce unnecessary send/receive input pipeline operation, we place
......@@ -686,6 +700,7 @@ class DistributedExecutor(object):
summary_writer(metrics=eval_metric_result, step=current_step)
reset_states(eval_metric)
tf.keras.backend.set_learning_phase(old_phrase)
return eval_metric_result, current_step
def predict(self):
......@@ -726,18 +741,20 @@ class ExecutorBuilder(object):
model_fn=my_model_fn,
loss_fn=my_loss_fn,
metric_fn=my_metric_fn)
"""
def __init__(self, strategy_type=None, strategy_config=None):
_ = distribution_utils.configure_cluster(
strategy_config.worker_hosts, strategy_config.task_index)
"""Constructor.
Args:
strategy_type: string. One of 'tpu', 'mirrored', 'multi_worker_mirrored'. If
None. User is responsible to set the strategy before calling
strategy_type: string. One of 'tpu', 'mirrored', 'multi_worker_mirrored'.
If None. User is responsible to set the strategy before calling
build_executor(...).
strategy_config: necessary config for constructing the proper Strategy.
Check strategy_flags_dict() for examples of the structure.
"""
def __init__(self, strategy_type=None, strategy_config=None):
_ = distribution_utils.configure_cluster(
strategy_config.worker_hosts, strategy_config.task_index)
self._strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=strategy_type,
num_gpus=strategy_config.num_gpus,
......@@ -755,7 +772,6 @@ class ExecutorBuilder(object):
"""Sets default summary writer for the current thread."""
self._strategy = new_strategy
def build_executor(self,
class_ctor=DistributedExecutor,
params=None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment