Commit ebc28058 authored by Ruoxin Sang, committed by A. Unique TensorFlower
Browse files

Make the Python training loop work properly with async eager.

PiperOrigin-RevId: 300127838
parent 6e070e77
...@@ -53,13 +53,17 @@ def create_loop_fn(step_fn): ...@@ -53,13 +53,17 @@ def create_loop_fn(step_fn):
""" """
try: try:
step = 0 step = 0
while (num_steps == -1 or step < num_steps): # To make sure the OutOfRangeError exception can be handled well with
outputs = step_fn(iterator) # async remote eager, we need to wrap the loop body in a `async_scope`.
if reduce_fn is not None: with tf.experimental.async_scope():
state = reduce_fn(state, outputs) while (num_steps == -1 or step < num_steps):
step += 1 outputs = step_fn(iterator)
return state if reduce_fn is not None:
state = reduce_fn(state, outputs)
step += 1
return state
except (StopIteration, tf.errors.OutOfRangeError): except (StopIteration, tf.errors.OutOfRangeError):
tf.experimental.async_clear_error()
return state return state
return loop_fn return loop_fn
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment