".github/git@developer.sourcefind.cn:OpenDAS/autoawq.git" did not exist on "51f89b6c8b850091771514656c7f263bba012603"
Commit a64acf7d authored by Ruoxin Sang's avatar Ruoxin Sang Committed by A. Unique TensorFlower
Browse files

Clear out the outer name scope so the ops created inside `tf.while_loop`

don't get "while/" as name prefix.

PiperOrigin-RevId: 351689649
parent de319a78
...@@ -112,7 +112,10 @@ def create_tf_while_loop_fn(step_fn): ...@@ -112,7 +112,10 @@ def create_tf_while_loop_fn(step_fn):
"cause unnecessary retracing when wrapped by `tf.function`.") "cause unnecessary retracing when wrapped by `tf.function`.")
for _ in tf.range(num_steps): for _ in tf.range(num_steps):
step_fn(iterator) # Clear out the outer name scope so the ops created inside `tf.while_loop`
# don't get "while/" as name prefix.
with tf.name_scope(""):
step_fn(iterator)
return loop_fn return loop_fn
...@@ -157,15 +160,18 @@ def create_tf_while_loop_fn_with_state(step_fn): ...@@ -157,15 +160,18 @@ def create_tf_while_loop_fn_with_state(step_fn):
"cause unnecessary retracing when wrapped by `tf.function`.") "cause unnecessary retracing when wrapped by `tf.function`.")
for _ in tf.range(num_steps): for _ in tf.range(num_steps):
# Relax the shapes within the loop, so the shape of `state` can change # Clear out the outer name scope so the ops created inside `tf.while_loop`
# across iterations. This is useful to aggregate outputs from each step # don't get "while/" as name prefix.
# and concat to `state`. with tf.name_scope(""):
tf.autograph.experimental.set_loop_options( # Relax the shapes within the loop, so the shape of `state` can change
shape_invariants=[(t, tf.TensorShape([None] * t.shape.rank)) # across iterations. This is useful to aggregate outputs from each step
for t in tf.nest.flatten(state) # and concat to `state`.
if tf.is_tensor(t)]) tf.autograph.experimental.set_loop_options(
outputs = step_fn(iterator) shape_invariants=[(t, tf.TensorShape([None] * t.shape.rank))
state = reduce_fn(state, outputs) for t in tf.nest.flatten(state)
if tf.is_tensor(t)])
outputs = step_fn(iterator)
state = reduce_fn(state, outputs)
return state return state
return loop_fn_with_state return loop_fn_with_state
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment