Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
7ef4a501
Commit
7ef4a501
authored
Sep 07, 2021
by
Ruoxin Sang
Committed by
A. Unique TensorFlower
Sep 07, 2021
Browse files
Fix a bug that `loop_fns.create_tf_while_loop_fn` doesn't handle nested structure states.
PiperOrigin-RevId: 395317528
parent
55333759
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
19 additions
and
6 deletions
+19
-6
orbit/standard_runner_test.py
orbit/standard_runner_test.py
+3
-3
orbit/utils/loop_fns.py
orbit/utils/loop_fns.py
+16
-3
No files found.
orbit/standard_runner_test.py
View file @
7ef4a501
...
...
@@ -91,10 +91,10 @@ class TestEvaluatorWithOutputsAggregation(standard_runner.StandardEvaluator):
super
().
__init__
(
eval_dataset
=
dataset
,
options
=
options
)
def
eval_begin
(
self
):
return
tf
.
constant
((
0.0
,))
return
{
"logits"
:
tf
.
constant
((
0.0
,))
}
def
eval_reduce
(
self
,
state
,
step_outputs
):
state
=
tf
.
concat
([
state
,
step_outputs
],
0
)
state
[
"logits"
]
=
tf
.
concat
([
state
[
"logits"
]
,
step_outputs
],
0
)
return
state
def
eval_step
(
self
,
iterator
):
...
...
@@ -107,7 +107,7 @@ class TestEvaluatorWithOutputsAggregation(standard_runner.StandardEvaluator):
self
.
strategy
.
run
(
replica_step
,
args
=
(
next
(
iterator
),)))
def
eval_end
(
self
,
outputs
):
return
tf
.
reduce_sum
(
outputs
)
return
tf
.
reduce_sum
(
outputs
[
"logits"
]
)
class
StandardRunnerTest
(
parameterized
.
TestCase
):
...
...
orbit/utils/loop_fns.py
View file @
7ef4a501
...
...
@@ -159,6 +159,21 @@ def create_tf_while_loop_fn_with_state(step_fn):
"`num_steps` should be a `tf.Tensor`. Passing a Python value can "
"cause unnecessary retracing when wrapped by `tf.function`."
)
def
_get_relaxed_tensor_shape
(
t
):
"""Returns a `TensorShape` with all `None` dimensions."""
if
not
tf
.
is_tensor
(
t
):
return
None
shape
=
t
.
shape
if
shape
.
rank
is
not
None
and
shape
.
rank
>
0
:
return
tf
.
TensorShape
([
None
]
*
shape
.
rank
)
return
shape
def
_get_relaxed_shape_structure
(
s
):
"""Returns the relaxed shape of the input nested structure `s`."""
return
tf
.
nest
.
pack_sequence_as
(
state
,
[
_get_relaxed_tensor_shape
(
t
)
for
t
in
tf
.
nest
.
flatten
(
s
)])
for
_
in
tf
.
range
(
num_steps
):
# Clear out the outer name scope so the ops created inside `tf.while_loop`
# don't get "while/" as name prefix.
...
...
@@ -167,9 +182,7 @@ def create_tf_while_loop_fn_with_state(step_fn):
# across iterations. This is useful to aggregate outputs from each step
# and concat to `state`.
tf
.
autograph
.
experimental
.
set_loop_options
(
shape_invariants
=
[(
t
,
tf
.
TensorShape
([
None
]
*
t
.
shape
.
rank
))
for
t
in
tf
.
nest
.
flatten
(
state
)
if
tf
.
is_tensor
(
t
)])
shape_invariants
=
[(
state
,
_get_relaxed_shape_structure
(
state
))])
outputs
=
step_fn
(
iterator
)
state
=
reduce_fn
(
state
,
outputs
)
return
state
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment