Commit 2d7223ad authored by Ken Franko's avatar Ken Franko Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 301277545
parent 64a48653
......@@ -321,7 +321,7 @@ def run_customized_training_loop(
'retracing.')
for _ in tf.range(steps):
strategy.experimental_run_v2(_replicated_step, args=(next(iterator),))
strategy.run(_replicated_step, args=(next(iterator),))
def train_single_step(iterator):
"""Performs a distributed training step.
......@@ -332,7 +332,7 @@ def run_customized_training_loop(
Raises:
ValueError: Any of the arguments or tensor shapes are invalid.
"""
strategy.experimental_run_v2(_replicated_step, args=(next(iterator),))
strategy.run(_replicated_step, args=(next(iterator),))
def test_step(iterator):
"""Calculates evaluation metrics on distributed devices."""
......@@ -345,7 +345,7 @@ def run_customized_training_loop(
for metric in eval_metrics:
metric.update_state(labels, model_outputs)
strategy.experimental_run_v2(_test_step_fn, args=(next(iterator),))
strategy.run(_test_step_fn, args=(next(iterator),))
if not run_eagerly:
train_single_step = tf.function(train_single_step)
......
......@@ -243,10 +243,10 @@ class DistributedExecutor(object):
raise ValueError('steps should be an Tensor. Python object may cause '
'retracing.')
per_replica_losses = strategy.experimental_run_v2(
per_replica_losses = strategy.run(
_replicated_step, args=(next(iterator),))
for _ in tf.range(num_steps - 1):
per_replica_losses = strategy.experimental_run_v2(
per_replica_losses = strategy.run(
_replicated_step, args=(next(iterator),))
# For reporting, we returns the mean of losses.
......@@ -278,7 +278,7 @@ class DistributedExecutor(object):
metric.update_state(labels, model_outputs)
return labels, model_outputs
return strategy.experimental_run_v2(_test_step_fn, args=(next(iterator),))
return strategy.run(_test_step_fn, args=(next(iterator),))
return test_step
......
......@@ -267,7 +267,7 @@ def get_predictions_and_labels(strategy, trained_model, eval_input_fn,
model_outputs = trained_model(inputs, training=False)
return model_outputs, labels
outputs, labels = strategy.experimental_run_v2(
outputs, labels = strategy.run(
_test_step_fn, args=(next(iterator),))
# outputs: current batch logits as a tuple of shard logits
outputs = tf.nest.map_structure(strategy.experimental_local_results,
......
......@@ -194,7 +194,7 @@ def predict_squad_customized(strategy, input_meta_data, bert_config,
start_logits=start_logits,
end_logits=end_logits)
outputs = strategy.experimental_run_v2(
outputs = strategy.run(
_replicated_step, args=(next(iterator),))
return tf.nest.map_structure(strategy.experimental_local_results, outputs)
......
......@@ -280,7 +280,7 @@ class TransformerTask(object):
for _ in tf.range(steps):
train_loss_metric.reset_states()
self.distribution_strategy.experimental_run_v2(
self.distribution_strategy.run(
_step_fn, args=(next(iterator),))
cased_score, uncased_score = None, None
......
......@@ -132,7 +132,7 @@ def translate_file(model,
val_outputs, _ = model([val_inputs], training=False)
return tag, val_outputs
return distribution_strategy.experimental_run_v2(_step_fn, args=(inputs,))
return distribution_strategy.run(_step_fn, args=(inputs,))
translations = []
if distribution_strategy:
......@@ -151,7 +151,7 @@ def translate_file(model,
replica_id = replica_context.replica_id_in_sync_group
return replica_id, text[replica_id]
text = distribution_strategy.experimental_run_v2(text_as_per_replica)
text = distribution_strategy.run(text_as_per_replica)
outputs = distribution_strategy.experimental_local_results(
predict_step(text))
tags, unordered_val_outputs = outputs[0]
......
......@@ -87,7 +87,7 @@ def run_evaluation(strategy,
@tf.function
def _run_evaluation(test_iterator):
"""Runs validation steps."""
logits, labels, masks = strategy.experimental_run_v2(
logits, labels, masks = strategy.run(
_test_step_fn, args=(next(test_iterator),))
return logits, labels, masks
......
......@@ -130,7 +130,7 @@ def run_evaluation(strategy, test_input_fn, eval_examples, eval_features,
@tf.function
def _run_evaluation(test_iterator):
"""Runs validation steps."""
res, unique_ids = strategy.experimental_run_v2(
res, unique_ids = strategy.run(
_test_step_fn, args=(next(test_iterator),))
return res, unique_ids
......
......@@ -222,16 +222,16 @@ def train(
return mems
if input_meta_data["mem_len"] > 0:
mem = strategy.experimental_run_v2(cache_fn)
mem = strategy.run(cache_fn)
for _ in tf.range(steps):
mem = strategy.experimental_run_v2(
mem = strategy.run(
_replicated_step, args=(
next(iterator),
mem,
))
else:
for _ in tf.range(steps):
strategy.experimental_run_v2(_replicated_step, args=(next(iterator),))
strategy.run(_replicated_step, args=(next(iterator),))
if not run_eagerly:
train_steps = tf.function(train_steps)
......
......@@ -405,7 +405,7 @@ def run_ncf_custom_training(params,
optimizer.apply_gradients(grads)
return loss
per_replica_losses = strategy.experimental_run_v2(
per_replica_losses = strategy.run(
step_fn, args=(next(train_iterator),))
mean_loss = strategy.reduce(
tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
......@@ -425,7 +425,7 @@ def run_ncf_custom_training(params,
return hr_sum, hr_count
per_replica_hr_sum, per_replica_hr_count = (
strategy.experimental_run_v2(
strategy.run(
step_fn, args=(next(eval_iterator),)))
hr_sum = strategy.reduce(
tf.distribute.ReduceOp.SUM, per_replica_hr_sum, axis=None)
......
......@@ -39,7 +39,7 @@ class AbstractTrainable(tf.Module):
python callbacks. This is necessary for getting good performance in TPU
training, as the overhead for launching a multi worker tf.function may be
large in Eager mode. It is usually encouraged to create a host training loop
(e.g. using a `tf.range` wrapping `strategy.experimental_run_v2` inside a
(e.g. using a `tf.range` wrapping `strategy.run` inside a
`tf.function`) in the TPU case. For the cases that don't require host
training loop to achieve peak performance, users can just implement a simple
python loop to drive each step.
......
......@@ -87,7 +87,7 @@ class StandardTrainable(runnable.AbstractTrainable):
What a "step" consists of is up to the implementer. If using distribution
strategies, the call to this method should take place in the "cross-replica
context" for generality, to allow e.g. multiple iterator dequeues and calls
to `strategy.experimental_run_v2`.
to `strategy.run`.
Args:
iterator: A tf.nest-compatible structure of tf.data Iterator or
......@@ -163,7 +163,7 @@ class StandardEvaluable(runnable.AbstractEvaluable):
What a "step" consists of is up to the implementer. If using distribution
strategies, the call to this method should take place in the "cross-replica
context" for generality, to allow e.g. multiple iterator dequeues and calls
to `strategy.experimental_run_v2`.
to `strategy.run`.
Args:
iterator: A tf.nest-compatible structure of tf.data Iterator or
......
......@@ -119,7 +119,7 @@ class DetectionDistributedExecutor(executor.DistributedExecutor):
return labels, prediction_outputs
labels, outputs = strategy.experimental_run_v2(
labels, outputs = strategy.run(
_test_step_fn, args=(
next(iterator),
eval_steps,
......
......@@ -175,7 +175,7 @@ class ResnetRunnable(standard_runnable.StandardTrainable,
self.train_loss.update_state(loss)
self.train_accuracy.update_state(labels, logits)
self.strategy.experimental_run_v2(step_fn, args=(next(iterator),))
self.strategy.run(step_fn, args=(next(iterator),))
def train_loop_end(self):
"""See base class."""
......@@ -204,7 +204,7 @@ class ResnetRunnable(standard_runnable.StandardTrainable,
self.test_loss.update_state(loss)
self.test_accuracy.update_state(labels, logits)
self.strategy.experimental_run_v2(step_fn, args=(next(iterator),))
self.strategy.run(step_fn, args=(next(iterator),))
def eval_end(self):
"""See base class."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment