Commit ebc28058 authored by Ruoxin Sang, committed by A. Unique TensorFlower
Browse files

Make the Python training loop handle OutOfRangeError properly with async eager.

PiperOrigin-RevId: 300127838
parent 6e070e77
...@@ -53,6 +53,9 @@ def create_loop_fn(step_fn): ...@@ -53,6 +53,9 @@ def create_loop_fn(step_fn):
""" """
try: try:
step = 0 step = 0
# To make sure the OutOfRangeError exception can be handled well with
# async remote eager, we need to wrap the loop body in an `async_scope`.
with tf.experimental.async_scope():
while (num_steps == -1 or step < num_steps): while (num_steps == -1 or step < num_steps):
outputs = step_fn(iterator) outputs = step_fn(iterator)
if reduce_fn is not None: if reduce_fn is not None:
...@@ -60,6 +63,7 @@ def create_loop_fn(step_fn): ...@@ -60,6 +63,7 @@ def create_loop_fn(step_fn):
step += 1 step += 1
return state return state
except (StopIteration, tf.errors.OutOfRangeError): except (StopIteration, tf.errors.OutOfRangeError):
tf.experimental.async_clear_error()
return state return state
return loop_fn return loop_fn
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment