Commit 7ef4a501 authored by Ruoxin Sang's avatar Ruoxin Sang Committed by A. Unique TensorFlower
Browse files

Fix a bug that `loop_fns.create_tf_while_loop_fn` doesn't handle nested structure states.

PiperOrigin-RevId: 395317528
parent 55333759
......@@ -91,10 +91,10 @@ class TestEvaluatorWithOutputsAggregation(standard_runner.StandardEvaluator):
super().__init__(eval_dataset=dataset, options=options)
def eval_begin(self):
return tf.constant((0.0,))
return {"logits": tf.constant((0.0,))}
def eval_reduce(self, state, step_outputs):
state = tf.concat([state, step_outputs], 0)
state["logits"] = tf.concat([state["logits"], step_outputs], 0)
return state
def eval_step(self, iterator):
......@@ -107,7 +107,7 @@ class TestEvaluatorWithOutputsAggregation(standard_runner.StandardEvaluator):
self.strategy.run(replica_step, args=(next(iterator),)))
def eval_end(self, outputs):
return tf.reduce_sum(outputs)
return tf.reduce_sum(outputs["logits"])
class StandardRunnerTest(parameterized.TestCase):
......
......@@ -159,6 +159,21 @@ def create_tf_while_loop_fn_with_state(step_fn):
"`num_steps` should be a `tf.Tensor`. Passing a Python value can "
"cause unnecessary retracing when wrapped by `tf.function`.")
def _get_relaxed_tensor_shape(t):
"""Returns a `TensorShape` with all `None` dimensions."""
if not tf.is_tensor(t):
return None
shape = t.shape
if shape.rank is not None and shape.rank > 0:
return tf.TensorShape([None] * shape.rank)
return shape
def _get_relaxed_shape_structure(s):
"""Returns the relaxed shape of the input nested structure `s`."""
return tf.nest.pack_sequence_as(
state, [_get_relaxed_tensor_shape(t) for t in tf.nest.flatten(s)])
for _ in tf.range(num_steps):
# Clear out the outer name scope so the ops created inside `tf.while_loop`
# don't get "while/" as name prefix.
......@@ -167,9 +182,7 @@ def create_tf_while_loop_fn_with_state(step_fn):
# across iterations. This is useful to aggregate outputs from each step
# and concat to `state`.
tf.autograph.experimental.set_loop_options(
shape_invariants=[(t, tf.TensorShape([None] * t.shape.rank))
for t in tf.nest.flatten(state)
if tf.is_tensor(t)])
shape_invariants=[(state, _get_relaxed_shape_structure(state))])
outputs = step_fn(iterator)
state = reduce_fn(state, outputs)
return state
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment