Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
90dedf26
Commit
90dedf26
authored
Aug 09, 2022
by
Ruoxin Sang
Committed by
A. Unique TensorFlower
Aug 09, 2022
Browse files
Allow `steps_per_loop` in Controller to be passed as a callable.
PiperOrigin-RevId: 466412169
parent
db19ab9b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
47 additions
and
11 deletions
+47
-11
orbit/controller.py
orbit/controller.py
+21
-11
orbit/controller_test.py
orbit/controller_test.py
+26
-0
No files found.
orbit/controller.py
View file @
90dedf26
...
...
@@ -94,7 +94,7 @@ class Controller:
train_actions
:
Optional
[
Iterable
[
Action
]]
=
None
,
eval_actions
:
Optional
[
Iterable
[
Action
]]
=
None
,
# Train related
steps_per_loop
:
Optional
[
int
]
=
None
,
steps_per_loop
:
Optional
[
Union
[
int
,
Callable
[[
int
],
int
]]
]
=
None
,
checkpoint_manager
:
Optional
[
tf
.
train
.
CheckpointManager
]
=
None
,
# Summary related
summary_interval
:
Optional
[
int
]
=
None
,
...
...
@@ -130,8 +130,11 @@ class Controller:
output of `trainer.train`.
eval_actions: Optional `orbit.Action`s to call after each evaluation.
These will be called with the output of `evaluator.evaluate`.
steps_per_loop: The number of steps to run in each inner loop of training
(passed as the `num_steps` parameter of `trainer.train`).
steps_per_loop: Optional integer to indicate the number of steps to run in
each inner loop of training (passed as the `num_steps` parameter of
`trainer.train`). It can be also a callable which takes the current
global step value as input and returns the number of steps to run as
output.
checkpoint_manager: An instance of `tf.train.CheckpointManager`. If
provided and there are checkpoints in the associated model directory,
the model will be restored from the most recent checkpoint inside this
...
...
@@ -152,7 +155,7 @@ class Controller:
Raises:
ValueError: If both `trainer` and `evaluator` are `None`.
ValueError: If `steps_per_loop` is not a positive integer.
ValueError: If `steps_per_loop` is not a positive integer
or a callable
.
ValueError: If `summary_interval` is not a positive integer or is not
divisible by `steps_per_loop`.
"""
...
...
@@ -163,15 +166,18 @@ class Controller:
if
steps_per_loop
is
None
:
raise
ValueError
(
"`steps_per_loop` is required when `trainer` is provided."
)
elif
not
isinstance
(
steps_per_loop
,
int
)
or
steps_per_loop
<
1
:
elif
not
callable
(
steps_per_loop
)
and
(
not
isinstance
(
steps_per_loop
,
int
)
or
steps_per_loop
<
1
):
raise
ValueError
(
f
"`steps_per_loop` (
{
steps_per_loop
}
) must be a positive integer."
)
f
"`steps_per_loop` (
{
steps_per_loop
}
) must be a positive integer "
"or a callable."
)
if
summary_interval
is
not
None
:
if
summary_interval
<=
0
:
raise
ValueError
(
f
"`summary_interval` (
{
summary_interval
}
) must be larger than 0."
)
elif
summary_interval
%
steps_per_loop
!=
0
:
elif
not
callable
(
steps_per_loop
)
and
(
summary_interval
%
steps_per_loop
!=
0
):
raise
ValueError
(
f
"`summary interval` (
{
summary_interval
}
) must be a multiple "
f
"of `steps_per_loop` (
{
steps_per_loop
}
)."
)
...
...
@@ -192,10 +198,10 @@ class Controller:
if
self
.
trainer
is
not
None
:
self
.
step_timer
=
None
self
.
steps_per_loop
=
steps_per_loop
self
.
summary_interval
=
summary_interval
self
.
summary_manager
=
utils
.
SummaryManager
(
summary_dir
,
tf
.
summary
.
scalar
,
global_step
=
self
.
global_step
)
self
.
_steps_per_loop
=
steps_per_loop
if
self
.
evaluator
is
not
None
:
eval_summary_dir
=
eval_summary_dir
or
summary_dir
...
...
@@ -316,9 +322,6 @@ class Controller:
results in a shorter inner loop than specified by `steps_per_loop`
setting. If None, evaluation will only be performed after training is
complete.
Raises:
ValueError: If eval_interval is not a multiple of self.steps_per_loop.
"""
self
.
_require
(
"trainer"
,
for_method
=
"train_and_evaluate"
)
self
.
_require
(
"evaluator"
,
for_method
=
"train_and_evaluate"
)
...
...
@@ -410,6 +413,13 @@ class Controller:
self
.
_require
(
"checkpoint_manager"
,
for_method
=
"save_checkpoint"
)
self
.
_maybe_save_checkpoint
(
check_interval
=
False
)
@
property
def
steps_per_loop
(
self
):
"""Returns current steps_per_loop value in a training loop."""
if
callable
(
self
.
_steps_per_loop
):
return
self
.
_steps_per_loop
(
self
.
global_step
.
numpy
())
return
self
.
_steps_per_loop
def
_train_n_steps
(
self
,
num_steps
:
int
):
"""Runs training for `num_steps` steps.
...
...
orbit/controller_test.py
View file @
90dedf26
...
...
@@ -770,6 +770,32 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
self
.
assertIn
(
"eval_loss"
,
output
)
self
.
assertGreaterEqual
(
output
[
"eval_loss"
],
0
)
def
test_step_per_loop_callable
(
self
):
test_runner
=
TestRunner
()
checkpoint
=
tf
.
train
.
Checkpoint
(
model
=
test_runner
.
model
,
optimizer
=
test_runner
.
optimizer
)
checkpoint_manager
=
tf
.
train
.
CheckpointManager
(
checkpoint
,
self
.
model_dir
,
max_to_keep
=
None
,
step_counter
=
test_runner
.
global_step
,
checkpoint_interval
=
10
)
def
steps_per_loop_fn
(
global_step
):
if
global_step
>
4
:
return
4
return
2
test_controller
=
controller
.
Controller
(
trainer
=
test_runner
,
global_step
=
test_runner
.
global_step
,
steps_per_loop
=
steps_per_loop_fn
,
checkpoint_manager
=
checkpoint_manager
,
)
test_controller
.
train
(
steps
=
10
)
self
.
assertEqual
(
test_runner
.
global_step
,
10
)
if
__name__
==
"__main__"
:
tf
.
test
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment