Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
90dedf26
Commit
90dedf26
authored
Aug 09, 2022
by
Ruoxin Sang
Committed by
A. Unique TensorFlower
Aug 09, 2022
Browse files
Allow `steps_per_loop` in Controller to be passed as a callable.
PiperOrigin-RevId: 466412169
parent
db19ab9b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
47 additions
and
11 deletions
+47
-11
orbit/controller.py
orbit/controller.py
+21
-11
orbit/controller_test.py
orbit/controller_test.py
+26
-0
No files found.
orbit/controller.py
View file @
90dedf26
...
@@ -94,7 +94,7 @@ class Controller:
...
@@ -94,7 +94,7 @@ class Controller:
train_actions
:
Optional
[
Iterable
[
Action
]]
=
None
,
train_actions
:
Optional
[
Iterable
[
Action
]]
=
None
,
eval_actions
:
Optional
[
Iterable
[
Action
]]
=
None
,
eval_actions
:
Optional
[
Iterable
[
Action
]]
=
None
,
# Train related
# Train related
steps_per_loop
:
Optional
[
int
]
=
None
,
steps_per_loop
:
Optional
[
Union
[
int
,
Callable
[[
int
],
int
]]
]
=
None
,
checkpoint_manager
:
Optional
[
tf
.
train
.
CheckpointManager
]
=
None
,
checkpoint_manager
:
Optional
[
tf
.
train
.
CheckpointManager
]
=
None
,
# Summary related
# Summary related
summary_interval
:
Optional
[
int
]
=
None
,
summary_interval
:
Optional
[
int
]
=
None
,
...
@@ -130,8 +130,11 @@ class Controller:
...
@@ -130,8 +130,11 @@ class Controller:
output of `trainer.train`.
output of `trainer.train`.
eval_actions: Optional `orbit.Action`s to call after each evaluation.
eval_actions: Optional `orbit.Action`s to call after each evaluation.
These will be called with the output of `evaluator.evaluate`.
These will be called with the output of `evaluator.evaluate`.
steps_per_loop: The number of steps to run in each inner loop of training
steps_per_loop: Optional integer to indicate the number of steps to run in
(passed as the `num_steps` parameter of `trainer.train`).
each inner loop of training (passed as the `num_steps` parameter of
`trainer.train`). It can also be a callable that takes the current
global step value as input and returns the number of steps to run as
output.
checkpoint_manager: An instance of `tf.train.CheckpointManager`. If
checkpoint_manager: An instance of `tf.train.CheckpointManager`. If
provided and there are checkpoints in the associated model directory,
provided and there are checkpoints in the associated model directory,
the model will be restored from the most recent checkpoint inside this
the model will be restored from the most recent checkpoint inside this
...
@@ -152,7 +155,7 @@ class Controller:
...
@@ -152,7 +155,7 @@ class Controller:
Raises:
Raises:
ValueError: If both `trainer` and `evaluator` are `None`.
ValueError: If both `trainer` and `evaluator` are `None`.
ValueError: If `steps_per_loop` is not a positive integer.
ValueError: If `steps_per_loop` is not a positive integer or a callable.
ValueError: If `summary_interval` is not a positive integer or is not
ValueError: If `summary_interval` is not a positive integer or is not
divisible by `steps_per_loop`.
divisible by `steps_per_loop`.
"""
"""
...
@@ -163,15 +166,18 @@ class Controller:
...
@@ -163,15 +166,18 @@ class Controller:
if
steps_per_loop
is
None
:
if
steps_per_loop
is
None
:
raise
ValueError
(
raise
ValueError
(
"`steps_per_loop` is required when `trainer` is provided."
)
"`steps_per_loop` is required when `trainer` is provided."
)
elif
not
isinstance
(
steps_per_loop
,
int
)
or
steps_per_loop
<
1
:
elif
not
callable
(
steps_per_loop
)
and
(
not
isinstance
(
steps_per_loop
,
int
)
or
steps_per_loop
<
1
):
raise
ValueError
(
raise
ValueError
(
f
"`steps_per_loop` (
{
steps_per_loop
}
) must be a positive integer."
)
f
"`steps_per_loop` (
{
steps_per_loop
}
) must be a positive integer "
"or a callable."
)
if
summary_interval
is
not
None
:
if
summary_interval
is
not
None
:
if
summary_interval
<=
0
:
if
summary_interval
<=
0
:
raise
ValueError
(
raise
ValueError
(
f
"`summary_interval` (
{
summary_interval
}
) must be larger than 0."
)
f
"`summary_interval` (
{
summary_interval
}
) must be larger than 0."
)
elif
summary_interval
%
steps_per_loop
!=
0
:
elif
not
callable
(
steps_per_loop
)
and
(
summary_interval
%
steps_per_loop
!=
0
):
raise
ValueError
(
raise
ValueError
(
f
"`summary interval` (
{
summary_interval
}
) must be a multiple "
f
"`summary interval` (
{
summary_interval
}
) must be a multiple "
f
"of `steps_per_loop` (
{
steps_per_loop
}
)."
)
f
"of `steps_per_loop` (
{
steps_per_loop
}
)."
)
...
@@ -192,10 +198,10 @@ class Controller:
...
@@ -192,10 +198,10 @@ class Controller:
if
self
.
trainer
is
not
None
:
if
self
.
trainer
is
not
None
:
self
.
step_timer
=
None
self
.
step_timer
=
None
self
.
steps_per_loop
=
steps_per_loop
self
.
summary_interval
=
summary_interval
self
.
summary_interval
=
summary_interval
self
.
summary_manager
=
utils
.
SummaryManager
(
self
.
summary_manager
=
utils
.
SummaryManager
(
summary_dir
,
tf
.
summary
.
scalar
,
global_step
=
self
.
global_step
)
summary_dir
,
tf
.
summary
.
scalar
,
global_step
=
self
.
global_step
)
self
.
_steps_per_loop
=
steps_per_loop
if
self
.
evaluator
is
not
None
:
if
self
.
evaluator
is
not
None
:
eval_summary_dir
=
eval_summary_dir
or
summary_dir
eval_summary_dir
=
eval_summary_dir
or
summary_dir
...
@@ -316,9 +322,6 @@ class Controller:
...
@@ -316,9 +322,6 @@ class Controller:
results in a shorter inner loop than specified by `steps_per_loop`
results in a shorter inner loop than specified by `steps_per_loop`
setting. If None, evaluation will only be performed after training is
setting. If None, evaluation will only be performed after training is
complete.
complete.
Raises:
ValueError: If eval_interval is not a multiple of self.steps_per_loop.
"""
"""
self
.
_require
(
"trainer"
,
for_method
=
"train_and_evaluate"
)
self
.
_require
(
"trainer"
,
for_method
=
"train_and_evaluate"
)
self
.
_require
(
"evaluator"
,
for_method
=
"train_and_evaluate"
)
self
.
_require
(
"evaluator"
,
for_method
=
"train_and_evaluate"
)
...
@@ -410,6 +413,13 @@ class Controller:
...
@@ -410,6 +413,13 @@ class Controller:
self
.
_require
(
"checkpoint_manager"
,
for_method
=
"save_checkpoint"
)
self
.
_require
(
"checkpoint_manager"
,
for_method
=
"save_checkpoint"
)
self
.
_maybe_save_checkpoint
(
check_interval
=
False
)
self
.
_maybe_save_checkpoint
(
check_interval
=
False
)
@property
def steps_per_loop(self):
  """Returns the steps-per-loop value that applies at the current step.

  Returns:
    The stored `_steps_per_loop` value when it is not callable; otherwise
    the result of calling it with `self.global_step.numpy()`.
  """
  configured = self._steps_per_loop
  if not callable(configured):
    return configured
  # A callable setting is evaluated against the current global step value.
  return configured(self.global_step.numpy())
def
_train_n_steps
(
self
,
num_steps
:
int
):
def
_train_n_steps
(
self
,
num_steps
:
int
):
"""Runs training for `num_steps` steps.
"""Runs training for `num_steps` steps.
...
...
orbit/controller_test.py
View file @
90dedf26
...
@@ -770,6 +770,32 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -770,6 +770,32 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
self
.
assertIn
(
"eval_loss"
,
output
)
self
.
assertIn
(
"eval_loss"
,
output
)
self
.
assertGreaterEqual
(
output
[
"eval_loss"
],
0
)
self
.
assertGreaterEqual
(
output
[
"eval_loss"
],
0
)
def test_step_per_loop_callable(self):
  """Trains for 10 steps with a callable `steps_per_loop`.

  The callable returns 2 while the global step is at most 4 and 4 afterwards;
  the test checks that the runner's global step equals 10 after training.
  """
  runner = TestRunner()
  ckpt_manager = tf.train.CheckpointManager(
      tf.train.Checkpoint(model=runner.model, optimizer=runner.optimizer),
      self.model_dir,
      max_to_keep=None,
      step_counter=runner.global_step,
      checkpoint_interval=10)

  def steps_per_loop_fn(global_step):
    # Inner-loop length depends on the global step passed in.
    return 4 if global_step > 4 else 2

  test_controller = controller.Controller(
      trainer=runner,
      global_step=runner.global_step,
      steps_per_loop=steps_per_loop_fn,
      checkpoint_manager=ckpt_manager,
  )
  test_controller.train(steps=10)
  self.assertEqual(runner.global_step, 10)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
tf
.
test
.
main
()
tf
.
test
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment