"docs/tutorials/evaluation.md" did not exist on "5b3792fc3ef9ab6a6f8f30634ab2e52fb0941af3"
Commit 7ef4a501 authored by Ruoxin Sang, committed by A. Unique TensorFlower

Fix a bug where `loop_fns.create_tf_while_loop_fn` does not handle nested structure states.

PiperOrigin-RevId: 395317528
parent 55333759
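For context: the loop state threaded through `reduce_fn` can be a nested structure such as a dict of tensors rather than a single tensor. Before this change, `shape_invariants` was built from the flattened tensors of `state`, which are not themselves loop variables, so relaxing the shapes of a nested state presumably failed. Below is a minimal sketch of the technique the fix adopts; all names, shapes, and values here are illustrative assumptions, not taken from the commit:

    import tensorflow as tf


    def _relaxed_shape(t):
      # All-`None` TensorShape so a tensor may grow along every axis.
      if not tf.is_tensor(t):
        return None
      shape = t.shape
      if shape.rank is not None and shape.rank > 0:
        return tf.TensorShape([None] * shape.rank)
      return shape


    @tf.function
    def aggregate(num_steps):
      # Dict-valued state: the "nested structure" case the bug affected.
      state = {"logits": tf.zeros([0], tf.float32)}
      for i in tf.range(num_steps):
        # Declare the invariant on the structured loop variable itself, paired
        # with a matching structure of relaxed shapes, instead of on the
        # individual flattened tensors.
        tf.autograph.experimental.set_loop_options(
            shape_invariants=[(state,
                               tf.nest.map_structure(_relaxed_shape, state))])
        step_logits = tf.fill([2], tf.cast(i, tf.float32))
        state = {"logits": tf.concat([state["logits"], step_logits], 0)}
      return state


    print(aggregate(tf.constant(3))["logits"])  # 6 values: [0, 0, 1, 1, 2, 2]

Pairing the structured loop variable with a structure of relaxed `TensorShape`s lets the underlying `tf.while_loop` accept a state whose tensors grow along any axis across iterations.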
@@ -91,10 +91,10 @@ class TestEvaluatorWithOutputsAggregation(standard_runner.StandardEvaluator):
     super().__init__(eval_dataset=dataset, options=options)
 
   def eval_begin(self):
-    return tf.constant((0.0,))
+    return {"logits": tf.constant((0.0,))}
 
   def eval_reduce(self, state, step_outputs):
-    state = tf.concat([state, step_outputs], 0)
+    state["logits"] = tf.concat([state["logits"], step_outputs], 0)
     return state
 
   def eval_step(self, iterator):
@@ -107,7 +107,7 @@ class TestEvaluatorWithOutputsAggregation(standard_runner.StandardEvaluator):
         self.strategy.run(replica_step, args=(next(iterator),)))
 
   def eval_end(self, outputs):
-    return tf.reduce_sum(outputs)
+    return tf.reduce_sum(outputs["logits"])
 
 
 class StandardRunnerTest(parameterized.TestCase):
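The updated test now exercises exactly this path: the evaluator's state is a dict, so aggregation must go through the nested-structure handling in the loop helper. The following self-contained, eager sketch shows the arithmetic the overridden hooks perform; the per-step outputs are assumed values for illustration, not the real test dataset:

    import tensorflow as tf

    state = {"logits": tf.constant((0.0,))}                             # eval_begin
    for step_outputs in (tf.constant([1.0, 2.0]), tf.constant([3.0])):
      state["logits"] = tf.concat([state["logits"], step_outputs], 0)   # eval_reduce
    result = tf.reduce_sum(state["logits"])                             # eval_end
    print(result.numpy())  # 6.0 = 0.0 + 1.0 + 2.0 + 3.0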
@@ -159,6 +159,21 @@ def create_tf_while_loop_fn_with_state(step_fn):
           "`num_steps` should be a `tf.Tensor`. Passing a Python value can "
           "cause unnecessary retracing when wrapped by `tf.function`.")
 
+    def _get_relaxed_tensor_shape(t):
+      """Returns a `TensorShape` with all `None` dimensions."""
+      if not tf.is_tensor(t):
+        return None
+      shape = t.shape
+      if shape.rank is not None and shape.rank > 0:
+        return tf.TensorShape([None] * shape.rank)
+      return shape
+
+    def _get_relaxed_shape_structure(s):
+      """Returns the relaxed shape of the input nested structure `s`."""
+      return tf.nest.pack_sequence_as(
+          state, [_get_relaxed_tensor_shape(t) for t in tf.nest.flatten(s)])
+
     for _ in tf.range(num_steps):
       # Clear out the outer name scope so the ops created inside `tf.while_loop`
       # don't get "while/" as name prefix.
@@ -167,9 +182,7 @@ def create_tf_while_loop_fn_with_state(step_fn):
         # across iterations. This is useful to aggregate outputs from each step
        # and concat to `state`.
         tf.autograph.experimental.set_loop_options(
-            shape_invariants=[(t, tf.TensorShape([None] * t.shape.rank))
-                              for t in tf.nest.flatten(state)
-                              if tf.is_tensor(t)])
+            shape_invariants=[(state, _get_relaxed_shape_structure(state))])
         outputs = step_fn(iterator)
         state = reduce_fn(state, outputs)
     return state
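To make the new helpers concrete, here is a hedged, standalone illustration of what they compute for a nested state. The helper body is copied from the diff above so the snippet runs on its own; the example state mixing a growing vector with a scalar counter is an assumption for illustration, not from the test:

    import tensorflow as tf

    # Copied from the diff above, reproduced so the demo is self-contained.
    def _get_relaxed_tensor_shape(t):
      """Returns a `TensorShape` with all `None` dimensions."""
      if not tf.is_tensor(t):
        return None
      shape = t.shape
      if shape.rank is not None and shape.rank > 0:
        return tf.TensorShape([None] * shape.rank)
      return shape

    # Example state: an assumption for illustration, not the real fixture.
    state = {"logits": tf.zeros([3]), "count": tf.constant(0)}
    relaxed = tf.nest.pack_sequence_as(
        state, [_get_relaxed_tensor_shape(t) for t in tf.nest.flatten(state)])
    print(relaxed)  # logits -> TensorShape([None]), count -> TensorShape([])

The rank-0 branch matters here: the vector is relaxed to a fully unknown shape so it can grow across iterations, while the scalar counter keeps its exact (empty) shape, and any non-tensor leaf maps to `None`.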