Commit a64acf7d authored by Ruoxin Sang's avatar Ruoxin Sang Committed by A. Unique TensorFlower
Browse files

Clear out the outer name scope so the ops created inside `tf.while_loop`

don't get "while/" as name prefix.

PiperOrigin-RevId: 351689649
parent de319a78
......@@ -112,7 +112,10 @@ def create_tf_while_loop_fn(step_fn):
"cause unnecessary retracing when wrapped by `tf.function`.")
for _ in tf.range(num_steps):
step_fn(iterator)
# Clear out the outer name scope so the ops created inside `tf.while_loop`
# don't get "while/" as name prefix.
with tf.name_scope(""):
step_fn(iterator)
return loop_fn
......@@ -157,15 +160,18 @@ def create_tf_while_loop_fn_with_state(step_fn):
"cause unnecessary retracing when wrapped by `tf.function`.")
for _ in tf.range(num_steps):
# Relax the shapes within the loop, so the shape of `state` can change
# across iterations. This is useful to aggregate outputs from each step
# and concat to `state`.
tf.autograph.experimental.set_loop_options(
shape_invariants=[(t, tf.TensorShape([None] * t.shape.rank))
for t in tf.nest.flatten(state)
if tf.is_tensor(t)])
outputs = step_fn(iterator)
state = reduce_fn(state, outputs)
# Clear out the outer name scope so the ops created inside `tf.while_loop`
# don't get "while/" as name prefix.
with tf.name_scope(""):
# Relax the shapes within the loop, so the shape of `state` can change
# across iterations. This is useful to aggregate outputs from each step
# and concat to `state`.
tf.autograph.experimental.set_loop_options(
shape_invariants=[(t, tf.TensorShape([None] * t.shape.rank))
for t in tf.nest.flatten(state)
if tf.is_tensor(t)])
outputs = step_fn(iterator)
state = reduce_fn(state, outputs)
return state
return loop_fn_with_state
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment