Commit 2d7223ad authored by Ken Franko's avatar Ken Franko Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 301277545
parent 64a48653
...@@ -321,7 +321,7 @@ def run_customized_training_loop( ...@@ -321,7 +321,7 @@ def run_customized_training_loop(
'retracing.') 'retracing.')
for _ in tf.range(steps): for _ in tf.range(steps):
strategy.experimental_run_v2(_replicated_step, args=(next(iterator),)) strategy.run(_replicated_step, args=(next(iterator),))
def train_single_step(iterator): def train_single_step(iterator):
"""Performs a distributed training step. """Performs a distributed training step.
...@@ -332,7 +332,7 @@ def run_customized_training_loop( ...@@ -332,7 +332,7 @@ def run_customized_training_loop(
Raises: Raises:
ValueError: Any of the arguments or tensor shapes are invalid. ValueError: Any of the arguments or tensor shapes are invalid.
""" """
strategy.experimental_run_v2(_replicated_step, args=(next(iterator),)) strategy.run(_replicated_step, args=(next(iterator),))
def test_step(iterator): def test_step(iterator):
"""Calculates evaluation metrics on distributed devices.""" """Calculates evaluation metrics on distributed devices."""
...@@ -345,7 +345,7 @@ def run_customized_training_loop( ...@@ -345,7 +345,7 @@ def run_customized_training_loop(
for metric in eval_metrics: for metric in eval_metrics:
metric.update_state(labels, model_outputs) metric.update_state(labels, model_outputs)
strategy.experimental_run_v2(_test_step_fn, args=(next(iterator),)) strategy.run(_test_step_fn, args=(next(iterator),))
if not run_eagerly: if not run_eagerly:
train_single_step = tf.function(train_single_step) train_single_step = tf.function(train_single_step)
......
...@@ -243,10 +243,10 @@ class DistributedExecutor(object): ...@@ -243,10 +243,10 @@ class DistributedExecutor(object):
raise ValueError('steps should be an Tensor. Python object may cause ' raise ValueError('steps should be an Tensor. Python object may cause '
'retracing.') 'retracing.')
per_replica_losses = strategy.experimental_run_v2( per_replica_losses = strategy.run(
_replicated_step, args=(next(iterator),)) _replicated_step, args=(next(iterator),))
for _ in tf.range(num_steps - 1): for _ in tf.range(num_steps - 1):
per_replica_losses = strategy.experimental_run_v2( per_replica_losses = strategy.run(
_replicated_step, args=(next(iterator),)) _replicated_step, args=(next(iterator),))
# For reporting, we return the mean of losses. # For reporting, we return the mean of losses.
...@@ -278,7 +278,7 @@ class DistributedExecutor(object): ...@@ -278,7 +278,7 @@ class DistributedExecutor(object):
metric.update_state(labels, model_outputs) metric.update_state(labels, model_outputs)
return labels, model_outputs return labels, model_outputs
return strategy.experimental_run_v2(_test_step_fn, args=(next(iterator),)) return strategy.run(_test_step_fn, args=(next(iterator),))
return test_step return test_step
......
...@@ -267,7 +267,7 @@ def get_predictions_and_labels(strategy, trained_model, eval_input_fn, ...@@ -267,7 +267,7 @@ def get_predictions_and_labels(strategy, trained_model, eval_input_fn,
model_outputs = trained_model(inputs, training=False) model_outputs = trained_model(inputs, training=False)
return model_outputs, labels return model_outputs, labels
outputs, labels = strategy.experimental_run_v2( outputs, labels = strategy.run(
_test_step_fn, args=(next(iterator),)) _test_step_fn, args=(next(iterator),))
# outputs: current batch logits as a tuple of shard logits # outputs: current batch logits as a tuple of shard logits
outputs = tf.nest.map_structure(strategy.experimental_local_results, outputs = tf.nest.map_structure(strategy.experimental_local_results,
......
...@@ -194,7 +194,7 @@ def predict_squad_customized(strategy, input_meta_data, bert_config, ...@@ -194,7 +194,7 @@ def predict_squad_customized(strategy, input_meta_data, bert_config,
start_logits=start_logits, start_logits=start_logits,
end_logits=end_logits) end_logits=end_logits)
outputs = strategy.experimental_run_v2( outputs = strategy.run(
_replicated_step, args=(next(iterator),)) _replicated_step, args=(next(iterator),))
return tf.nest.map_structure(strategy.experimental_local_results, outputs) return tf.nest.map_structure(strategy.experimental_local_results, outputs)
......
...@@ -280,7 +280,7 @@ class TransformerTask(object): ...@@ -280,7 +280,7 @@ class TransformerTask(object):
for _ in tf.range(steps): for _ in tf.range(steps):
train_loss_metric.reset_states() train_loss_metric.reset_states()
self.distribution_strategy.experimental_run_v2( self.distribution_strategy.run(
_step_fn, args=(next(iterator),)) _step_fn, args=(next(iterator),))
cased_score, uncased_score = None, None cased_score, uncased_score = None, None
......
...@@ -132,7 +132,7 @@ def translate_file(model, ...@@ -132,7 +132,7 @@ def translate_file(model,
val_outputs, _ = model([val_inputs], training=False) val_outputs, _ = model([val_inputs], training=False)
return tag, val_outputs return tag, val_outputs
return distribution_strategy.experimental_run_v2(_step_fn, args=(inputs,)) return distribution_strategy.run(_step_fn, args=(inputs,))
translations = [] translations = []
if distribution_strategy: if distribution_strategy:
...@@ -151,7 +151,7 @@ def translate_file(model, ...@@ -151,7 +151,7 @@ def translate_file(model,
replica_id = replica_context.replica_id_in_sync_group replica_id = replica_context.replica_id_in_sync_group
return replica_id, text[replica_id] return replica_id, text[replica_id]
text = distribution_strategy.experimental_run_v2(text_as_per_replica) text = distribution_strategy.run(text_as_per_replica)
outputs = distribution_strategy.experimental_local_results( outputs = distribution_strategy.experimental_local_results(
predict_step(text)) predict_step(text))
tags, unordered_val_outputs = outputs[0] tags, unordered_val_outputs = outputs[0]
......
...@@ -87,7 +87,7 @@ def run_evaluation(strategy, ...@@ -87,7 +87,7 @@ def run_evaluation(strategy,
@tf.function @tf.function
def _run_evaluation(test_iterator): def _run_evaluation(test_iterator):
"""Runs validation steps.""" """Runs validation steps."""
logits, labels, masks = strategy.experimental_run_v2( logits, labels, masks = strategy.run(
_test_step_fn, args=(next(test_iterator),)) _test_step_fn, args=(next(test_iterator),))
return logits, labels, masks return logits, labels, masks
......
...@@ -130,7 +130,7 @@ def run_evaluation(strategy, test_input_fn, eval_examples, eval_features, ...@@ -130,7 +130,7 @@ def run_evaluation(strategy, test_input_fn, eval_examples, eval_features,
@tf.function @tf.function
def _run_evaluation(test_iterator): def _run_evaluation(test_iterator):
"""Runs validation steps.""" """Runs validation steps."""
res, unique_ids = strategy.experimental_run_v2( res, unique_ids = strategy.run(
_test_step_fn, args=(next(test_iterator),)) _test_step_fn, args=(next(test_iterator),))
return res, unique_ids return res, unique_ids
......
...@@ -222,16 +222,16 @@ def train( ...@@ -222,16 +222,16 @@ def train(
return mems return mems
if input_meta_data["mem_len"] > 0: if input_meta_data["mem_len"] > 0:
mem = strategy.experimental_run_v2(cache_fn) mem = strategy.run(cache_fn)
for _ in tf.range(steps): for _ in tf.range(steps):
mem = strategy.experimental_run_v2( mem = strategy.run(
_replicated_step, args=( _replicated_step, args=(
next(iterator), next(iterator),
mem, mem,
)) ))
else: else:
for _ in tf.range(steps): for _ in tf.range(steps):
strategy.experimental_run_v2(_replicated_step, args=(next(iterator),)) strategy.run(_replicated_step, args=(next(iterator),))
if not run_eagerly: if not run_eagerly:
train_steps = tf.function(train_steps) train_steps = tf.function(train_steps)
......
...@@ -405,7 +405,7 @@ def run_ncf_custom_training(params, ...@@ -405,7 +405,7 @@ def run_ncf_custom_training(params,
optimizer.apply_gradients(grads) optimizer.apply_gradients(grads)
return loss return loss
per_replica_losses = strategy.experimental_run_v2( per_replica_losses = strategy.run(
step_fn, args=(next(train_iterator),)) step_fn, args=(next(train_iterator),))
mean_loss = strategy.reduce( mean_loss = strategy.reduce(
tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None) tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
...@@ -425,7 +425,7 @@ def run_ncf_custom_training(params, ...@@ -425,7 +425,7 @@ def run_ncf_custom_training(params,
return hr_sum, hr_count return hr_sum, hr_count
per_replica_hr_sum, per_replica_hr_count = ( per_replica_hr_sum, per_replica_hr_count = (
strategy.experimental_run_v2( strategy.run(
step_fn, args=(next(eval_iterator),))) step_fn, args=(next(eval_iterator),)))
hr_sum = strategy.reduce( hr_sum = strategy.reduce(
tf.distribute.ReduceOp.SUM, per_replica_hr_sum, axis=None) tf.distribute.ReduceOp.SUM, per_replica_hr_sum, axis=None)
......
...@@ -39,7 +39,7 @@ class AbstractTrainable(tf.Module): ...@@ -39,7 +39,7 @@ class AbstractTrainable(tf.Module):
python callbacks. This is necessary for getting good performance in TPU python callbacks. This is necessary for getting good performance in TPU
training, as the overhead for launching a multi worker tf.function may be training, as the overhead for launching a multi worker tf.function may be
large in Eager mode. It is usually encouraged to create a host training loop large in Eager mode. It is usually encouraged to create a host training loop
(e.g. using a `tf.range` wrapping `strategy.experimental_run_v2` inside a (e.g. using a `tf.range` wrapping `strategy.run` inside a
`tf.function`) in the TPU case. For the cases that don't require host `tf.function`) in the TPU case. For the cases that don't require host
training loop to achieve peak performance, users can just implement a simple training loop to achieve peak performance, users can just implement a simple
python loop to drive each step. python loop to drive each step.
......
...@@ -87,7 +87,7 @@ class StandardTrainable(runnable.AbstractTrainable): ...@@ -87,7 +87,7 @@ class StandardTrainable(runnable.AbstractTrainable):
What a "step" consists of is up to the implementer. If using distribution What a "step" consists of is up to the implementer. If using distribution
strategies, the call to this method should take place in the "cross-replica strategies, the call to this method should take place in the "cross-replica
context" for generality, to allow e.g. multiple iterator dequeues and calls context" for generality, to allow e.g. multiple iterator dequeues and calls
to `strategy.experimental_run_v2`. to `strategy.run`.
Args: Args:
iterator: A tf.nest-compatible structure of tf.data Iterator or iterator: A tf.nest-compatible structure of tf.data Iterator or
...@@ -163,7 +163,7 @@ class StandardEvaluable(runnable.AbstractEvaluable): ...@@ -163,7 +163,7 @@ class StandardEvaluable(runnable.AbstractEvaluable):
What a "step" consists of is up to the implementer. If using distribution What a "step" consists of is up to the implementer. If using distribution
strategies, the call to this method should take place in the "cross-replica strategies, the call to this method should take place in the "cross-replica
context" for generality, to allow e.g. multiple iterator dequeues and calls context" for generality, to allow e.g. multiple iterator dequeues and calls
to `strategy.experimental_run_v2`. to `strategy.run`.
Args: Args:
iterator: A tf.nest-compatible structure of tf.data Iterator or iterator: A tf.nest-compatible structure of tf.data Iterator or
......
...@@ -119,7 +119,7 @@ class DetectionDistributedExecutor(executor.DistributedExecutor): ...@@ -119,7 +119,7 @@ class DetectionDistributedExecutor(executor.DistributedExecutor):
return labels, prediction_outputs return labels, prediction_outputs
labels, outputs = strategy.experimental_run_v2( labels, outputs = strategy.run(
_test_step_fn, args=( _test_step_fn, args=(
next(iterator), next(iterator),
eval_steps, eval_steps,
......
...@@ -175,7 +175,7 @@ class ResnetRunnable(standard_runnable.StandardTrainable, ...@@ -175,7 +175,7 @@ class ResnetRunnable(standard_runnable.StandardTrainable,
self.train_loss.update_state(loss) self.train_loss.update_state(loss)
self.train_accuracy.update_state(labels, logits) self.train_accuracy.update_state(labels, logits)
self.strategy.experimental_run_v2(step_fn, args=(next(iterator),)) self.strategy.run(step_fn, args=(next(iterator),))
def train_loop_end(self): def train_loop_end(self):
"""See base class.""" """See base class."""
...@@ -204,7 +204,7 @@ class ResnetRunnable(standard_runnable.StandardTrainable, ...@@ -204,7 +204,7 @@ class ResnetRunnable(standard_runnable.StandardTrainable,
self.test_loss.update_state(loss) self.test_loss.update_state(loss)
self.test_accuracy.update_state(labels, logits) self.test_accuracy.update_state(labels, logits)
self.strategy.experimental_run_v2(step_fn, args=(next(iterator),)) self.strategy.run(step_fn, args=(next(iterator),))
def eval_end(self): def eval_end(self):
"""See base class.""" """See base class."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment