Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
704a3fda
Commit
704a3fda
authored
Apr 16, 2022
by
Hongkun Yu
Committed by
A. Unique TensorFlower
Apr 16, 2022
Browse files
[Cleanup] Remove legacy recovery module inside the tf model garden.
PiperOrigin-RevId: 442311238
parent
36c4957b
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
0 additions
and
75 deletions
+0
-75
official/core/base_trainer.py
official/core/base_trainer.py
+0
-51
official/core/base_trainer_test.py
official/core/base_trainer_test.py
+0
-24
No files found.
official/core/base_trainer.py
View file @
704a3fda
...
...
@@ -33,57 +33,6 @@ ExperimentConfig = config_definitions.ExperimentConfig
# Module-level alias so the rest of this file can refer to the trainer
# config type without the full config_definitions prefix.
TrainerConfig = config_definitions.TrainerConfig
class Recovery:
  """Built-in model blowup recovery module.

  Watches the reported loss value against a configured upper bound. When
  the loss is NaN, or exceeds the bound after a warm-up period, the model
  is recovered by restoring the last good checkpoint from disk.
  """

  def __init__(self,
               loss_upper_bound: float,
               checkpoint_manager: tf.train.CheckpointManager,
               recovery_begin_steps: int = 0,
               recovery_max_trials: int = 3):
    """Initializes the recovery module.

    Args:
      loss_upper_bound: losses strictly greater than this value trigger
        recovery (once past `recovery_begin_steps`).
      checkpoint_manager: manager used to restore the last good checkpoint.
      recovery_begin_steps: global step before which the upper-bound check
        is skipped. NaN losses trigger recovery regardless of the step.
      recovery_max_trials: number of recoveries tolerated before the job is
        crashed with a RuntimeError.
    """
    # Number of times recovery has been triggered so far.
    self.recover_counter = 0
    self.recovery_begin_steps = recovery_begin_steps
    self.recovery_max_trials = recovery_max_trials
    self.loss_upper_bound = loss_upper_bound
    self.checkpoint_manager = checkpoint_manager

  def should_recover(self, loss_value, global_step):
    """Returns whether the loss warrants restoring from a checkpoint."""
    # NaN is always fatal, independent of the warm-up period.
    if tf.math.is_nan(loss_value):
      return True
    # The upper-bound check only applies once training has progressed past
    # the configured warm-up step.
    past_warmup = global_step >= self.recovery_begin_steps
    if past_warmup and loss_value > self.loss_upper_bound:
      return True
    return False

  def maybe_recover(self, loss_value, global_step):
    """Conditionally recovers the training by triggering checkpoint restoration.

    Args:
      loss_value: the loss value as a float.
      global_step: the number of global training steps.

    Raises:
      RuntimeError: when recovery happens more than the max number of trials,
        the job should crash.
    """
    if not self.should_recover(loss_value, global_step):
      return
    self.recover_counter += 1
    # Too many blowups: give up and crash the job instead of looping forever.
    if self.recover_counter > self.recovery_max_trials:
      raise RuntimeError(
          "The loss value is NaN or out of range after training loop and "
          f"this happens {self.recover_counter} times.")
    # Loads the previous good checkpoint.
    checkpoint_path = self.checkpoint_manager.restore_or_initialize()
    logging.warning(
        "Recovering the model from checkpoint: %s. The loss value becomes "
        "%f at step %d.", checkpoint_path, loss_value, global_step)
class
_AsyncTrainer
(
orbit
.
StandardTrainer
,
orbit
.
StandardEvaluator
):
"""Trainer class for both sync and async Strategy."""
...
...
official/core/base_trainer_test.py
View file @
704a3fda
...
...
@@ -150,30 +150,6 @@ class MockAsyncTrainer(trainer_lib._AsyncTrainer):
return
self
.
eval_global_step
.
numpy
()
class RecoveryTest(tf.test.TestCase):
  """Unit tests for the Recovery blowup-recovery module."""

  def test_recovery_module(self):
    # Build a minimal checkpoint + manager pair in a temp directory so the
    # recovery module has something to restore from.
    ckpt = tf.train.Checkpoint(v=tf.Variable(1, dtype=tf.int32))
    model_dir = self.get_temp_dir()
    manager = tf.train.CheckpointManager(ckpt, model_dir, max_to_keep=1)
    recovery_module = trainer_lib.Recovery(
        loss_upper_bound=1.0,
        checkpoint_manager=manager,
        recovery_begin_steps=1,
        recovery_max_trials=1)
    # Before the begin step, an out-of-range loss is still tolerated.
    self.assertFalse(recovery_module.should_recover(1.1, 0))
    # A healthy loss never triggers recovery.
    self.assertFalse(recovery_module.should_recover(0.1, 1))
    # Past the begin step, an out-of-range loss must trigger recovery.
    self.assertTrue(recovery_module.should_recover(1.1, 2))
    # First triggers the recovery once.
    recovery_module.maybe_recover(1.1, 10)
    # Second time, it raises.
    with self.assertRaisesRegex(RuntimeError, 'The loss value is NaN .*'):
      recovery_module.maybe_recover(1.1, 10)
class
TrainerTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
def
setUp
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment