ModelZoo / ResNet50_tensorflow / Commits / 08189186

Commit 08189186 authored Aug 05, 2021 by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 389019529
Parent: d088f0d5

Showing 7 changed files with 196 additions and 82 deletions (+196, -82)
official/core/actions.py (+69, -3)
official/core/actions_test.py (+35, -9)
official/core/base_trainer.py (+1, -5)
official/core/base_trainer_test.py (+0, -56)
official/core/train_lib.py (+2, -3)
official/core/train_lib_test.py (+89, -1)
official/recommendation/ranking/train.py (+0, -5)
official/core/actions.py

@@ -16,6 +16,7 @@
import os
from typing import List

from absl import logging
import gin
import orbit
@@ -119,6 +120,58 @@ class EMACheckpointing:
    self._optimizer.swap_weights()


class RecoveryAction:
  """Train action to recover from loss blowup.

  Checks the loss value against the given threshold. If applicable, recovers
  the model by reading the checkpoint on disk.
  """

  def __init__(self, checkpoint_manager: tf.train.CheckpointManager):
    self.checkpoint_manager = checkpoint_manager

  def __call__(self, _):
    """Recovers the training by triggering checkpoint restoration."""
    # Loads the previous good checkpoint.
    checkpoint_path = self.checkpoint_manager.restore_or_initialize()
    logging.warning('Recovering the model from checkpoint: %s.',
                    checkpoint_path)


class RecoveryCondition:
  """Recovery condition."""

  def __init__(self,
               global_step: tf.Variable,
               loss_upper_bound: float,
               recovery_begin_steps: int = 0,
               recovery_max_trials: int = 3):
    self.recover_counter = 0
    self.recovery_begin_steps = recovery_begin_steps
    self.recovery_max_trials = recovery_max_trials
    self.loss_upper_bound = loss_upper_bound
    self.global_step = global_step

  def __call__(self, outputs: orbit.runner.Output):
    loss_value = outputs['training_loss']
    if tf.math.is_nan(loss_value):
      self.recover_counter += 1
      if self.recover_counter > self.recovery_max_trials:
        raise RuntimeError(
            'The loss value is NaN after training loop and it happens %d times.'
            % self.recover_counter)
      return True
    if (self.global_step >= self.recovery_begin_steps and
        loss_value > self.loss_upper_bound):
      self.recover_counter += 1
      if self.recover_counter > self.recovery_max_trials:
        raise RuntimeError(
            f'The loss value is {loss_value}, which is larger than the bound '
            f'{self.loss_upper_bound}, happens {self.recover_counter} times.')
      return True
    return False


@gin.configurable
def get_eval_actions(params: config_definitions.ExperimentConfig,
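The unit tests added in official/core/actions_test.py (further down) exercise exactly this contract. As a quick orientation, here is a minimal sketch, not part of the commit, of how RecoveryCondition behaves when called with a step-output dict; it assumes this revision's official.core.actions is importable and that the code runs eagerly.

# Hedged sketch (not in this commit): the RecoveryCondition contract.
import numpy as np
import orbit
import tensorflow as tf

from official.core import actions

global_step = orbit.utils.create_global_step()
condition = actions.RecoveryCondition(
    global_step, loss_upper_bound=0.5, recovery_max_trials=2)

# A finite loss above loss_upper_bound trips the condition (returns True)
# and increments the internal trial counter.
assert condition({'training_loss': tf.constant(0.6)})
# A NaN loss trips it as well.
assert condition({'training_loss': tf.constant([np.nan], tf.float32)})
# Once the counter exceeds recovery_max_trials, the condition raises instead
# of returning True, which aborts the experiment.
try:
  condition({'training_loss': tf.constant(0.6)})
except RuntimeError as e:
  print('Gave up after too many recovery trials:', e)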
@@ -140,9 +193,10 @@ def get_eval_actions(

@gin.configurable
-def get_train_actions(params: config_definitions.ExperimentConfig,
-                      trainer: base_trainer.Trainer,
-                      model_dir: str) -> List[orbit.Action]:
+def get_train_actions(
+    params: config_definitions.ExperimentConfig, trainer: base_trainer.Trainer,
+    model_dir: str,
+    checkpoint_manager: tf.train.CheckpointManager) -> List[orbit.Action]:
  """Gets train actions for TFM trainer."""
  train_actions = []
  # Adds pruning callback actions.
@@ -153,4 +207,16 @@ def get_train_actions(params: config_definitions.ExperimentConfig,
        model=trainer.model,
        optimizer=trainer.optimizer))

  if params.trainer.recovery_max_trials >= 0:
    recovery_condition = RecoveryCondition(
        global_step=trainer.global_step,
        loss_upper_bound=params.trainer.loss_upper_bound,
        recovery_begin_steps=params.trainer.recovery_begin_steps,
        recovery_max_trials=params.trainer.recovery_max_trials,
    )
    recover_action = orbit.actions.ConditionalAction(
        condition=recovery_condition,
        action=RecoveryAction(checkpoint_manager),
    )
    train_actions.append(recover_action)
  return train_actions
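RecoveryAction itself is a thin wrapper over tf.train.CheckpointManager.restore_or_initialize(). The standalone sketch below, not part of the commit, shows the rollback mechanism it relies on at the tf.train level; the Keras model is only a stand-in for the trainer's model.

# Hedged sketch (not in this commit): the rollback that RecoveryAction triggers.
import tempfile

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.build((None, 4))

checkpoint = tf.train.Checkpoint(model=model)
manager = tf.train.CheckpointManager(
    checkpoint, directory=tempfile.mkdtemp(), max_to_keep=2)
manager.save()                      # the "last good" state on disk

good_weights = model.get_weights()
model.set_weights([w * 0.0 for w in good_weights])  # simulate a loss blowup

# This is the call RecoveryAction makes: restore the latest checkpoint if one
# exists, otherwise fall back to the manager's init_fn.
restored_path = manager.restore_or_initialize()
print('Recovered from:', restored_path)
assert all((a == b).all() for a, b in zip(model.get_weights(), good_weights))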
official/core/actions_test.py

@@ -17,6 +17,8 @@
import os

from absl.testing import parameterized
+import numpy as np
+import orbit
import tensorflow as tf

from tensorflow.python.distribute import combinations
@@ -35,17 +37,14 @@ class TestModel(tf.Module):
    return self.value


-def all_strategy_combinations():
-  return combinations.combine(
-      distribution=[
-          strategy_combinations.cloud_tpu_strategy,
-          strategy_combinations.one_device_strategy_gpu,
-      ],)
-
-
class ActionsTest(tf.test.TestCase, parameterized.TestCase):

-  @combinations.generate(all_strategy_combinations())
+  @combinations.generate(
+      combinations.combine(
+          distribution=[
+              strategy_combinations.cloud_tpu_strategy,
+              strategy_combinations.one_device_strategy_gpu,
+          ],))
  def test_ema_checkpointing(self, distribution):
    with distribution.scope():
      directory = self.create_tempdir()
@@ -76,6 +75,33 @@ class ActionsTest(tf.test.TestCase, parameterized.TestCase):
      # Checks model.value is 0 after swapping.
      self.assertEqual(model(), 0)

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.default_strategy,
              strategy_combinations.cloud_tpu_strategy,
              strategy_combinations.one_device_strategy_gpu,
          ],))
  def test_recovery_condition(self, distribution):
    with distribution.scope():
      global_step = orbit.utils.create_global_step()
      recover_condition = actions.RecoveryCondition(
          global_step, loss_upper_bound=0.5, recovery_max_trials=2)
      outputs = {'training_loss': 0.6}
      self.assertTrue(recover_condition(outputs))
      self.assertTrue(recover_condition(outputs))
      with self.assertRaises(RuntimeError):
        recover_condition(outputs)

      global_step = orbit.utils.create_global_step()
      recover_condition = actions.RecoveryCondition(
          global_step, loss_upper_bound=0.5, recovery_max_trials=2)
      outputs = {'training_loss': tf.constant([np.nan], tf.float32)}
      self.assertTrue(recover_condition(outputs))
      self.assertTrue(recover_condition(outputs))
      with self.assertRaises(RuntimeError):
        recover_condition(outputs)


if __name__ == '__main__':
  tf.test.main()
official/core/base_trainer.py

@@ -370,6 +370,7 @@ class Trainer(_AsyncTrainer):
    """Accesses the training checkpoint."""
    return self._checkpoint

+  # TODO(yejiayu): Remove this once all deps are fixed.
  def add_recovery(self, params: TrainerConfig,
                   checkpoint_manager: tf.train.CheckpointManager):
    if params.recovery_max_trials >= 0:
@@ -382,11 +383,6 @@
  def train_loop_end(self):
    """See base class."""
    self.join()
-    # Checks if the model numeric status is stable and conducts the checkpoint
-    # recovery accordingly.
-    if self._recovery:
-      self._recovery.maybe_recover(self.train_loss.result().numpy(),
-                                   self.global_step.numpy())
    logs = {}
    for metric in self.train_metrics + [self.train_loss]:
      logs[metric.name] = metric.result()
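Taken together with the actions.py and train_lib.py changes, the effect is that recovery moves out of the trainer and into the orbit action list: add_recovery survives only for backward compatibility (per the TODO above), and after this commit the base Trainer's train_loop_end no longer consults self._recovery. Below is a hedged sketch, not part of the commit, contrasting the two entry points; params, trainer, model_dir and checkpoint_manager are assumed to come from an existing TFM experiment.

# Hedged sketch (not in this commit): legacy vs. new recovery wiring.
import tensorflow as tf

from official.core import actions


def wire_recovery_legacy(params, trainer,
                         checkpoint_manager: tf.train.CheckpointManager):
  # Legacy entry point, kept only so existing callers keep working; after this
  # commit the base Trainer.train_loop_end no longer checks self._recovery.
  trainer.add_recovery(params.trainer, checkpoint_manager=checkpoint_manager)


def wire_recovery_actions(params, trainer, model_dir: str,
                          checkpoint_manager: tf.train.CheckpointManager):
  # New entry point: recovery is a ConditionalAction built here and invoked by
  # the orbit training loop after each training loop pass.
  return actions.get_train_actions(
      params, trainer, model_dir, checkpoint_manager=checkpoint_manager)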
official/core/base_trainer_test.py

@@ -19,7 +19,6 @@ import os
import sys

from absl.testing import parameterized
-import numpy as np
import orbit
import portpicker
import tensorflow as tf
@@ -337,61 +336,6 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
    self.assertTrue(
        tf.io.gfile.exists(os.path.join(model_dir, 'best_ckpt', 'info.json')))

-  def test_recovery(self):
-    config = cfg.ExperimentConfig(
-        trainer=cfg.TrainerConfig(
-            loss_upper_bound=0.5,
-            recovery_max_trials=2,
-            optimizer_config=cfg.OptimizationConfig({
-                'optimizer': {'type': 'sgd'},
-                'learning_rate': {'type': 'constant'}
-            })))
-    model_dir = self.get_temp_dir()
-    trainer = self.create_test_trainer(config, model_dir=model_dir)
-    checkpoint_manager = tf.train.CheckpointManager(
-        trainer.checkpoint, self.get_temp_dir(), max_to_keep=2)
-    checkpoint_manager.save()
-    trainer.add_recovery(config.trainer, checkpoint_manager=checkpoint_manager)
-    before_weights = trainer.model.get_weights()
-    _ = trainer.train(tf.convert_to_tensor(1, dtype=tf.int32))
-    # The training loss is 1.0 and upper_bound is 0.5, so the recover happens.
-    after_weights = trainer.model.get_weights()
-    for left, right in zip(before_weights, after_weights):
-      self.assertAllEqual(left, right)
-
-    # Let's the loss be NaN and max_trials = 0 to see RuntimeError.
-    config = cfg.ExperimentConfig(
-        trainer=cfg.TrainerConfig(
-            recovery_max_trials=0,
-            optimizer_config=cfg.OptimizationConfig({
-                'optimizer': {'type': 'sgd'},
-                'learning_rate': {'type': 'constant'}
-            })))
-    task = mock_task.MockTask(config.task, logging_dir=model_dir)
-
-    def build_losses(labels, model_outputs, aux_losses=None):
-      del labels, model_outputs
-      return tf.constant([np.nan], tf.float32) + aux_losses
-
-    task.build_losses = build_losses
-    trainer = trainer_lib.Trainer(
-        config,
-        task,
-        model=task.build_model(),
-        optimizer=task.create_optimizer(config.trainer.optimizer_config,
-                                        config.runtime))
-    trainer.add_recovery(config.trainer, checkpoint_manager=checkpoint_manager)
-    with self.assertRaises(RuntimeError):
-      _ = trainer.train(tf.convert_to_tensor(2, dtype=tf.int32))
-
  def test_model_with_compiled_loss(self):
    task = mock_task.MockTask()
    model = task.build_model()
official/core/train_lib.py

@@ -87,8 +87,6 @@ def run_experiment(
        step_counter=trainer.global_step,
        checkpoint_interval=params.trainer.checkpoint_interval,
        init_fn=trainer.initialize)
-    # Adds recovery handling.
-    trainer.add_recovery(params.trainer, checkpoint_manager=checkpoint_manager)
  else:
    checkpoint_manager = None
@@ -105,7 +103,8 @@ def run_experiment(
      (save_summary) else None,
      summary_interval=params.trainer.summary_interval if
      (save_summary) else None,
-      train_actions=actions.get_train_actions(params, trainer, model_dir),
+      train_actions=actions.get_train_actions(
+          params, trainer, model_dir, checkpoint_manager=checkpoint_manager),
      eval_actions=actions.get_eval_actions(params, trainer, model_dir))

  logging.info('Starts to execute mode: %s', mode)
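The net effect in run_experiment is that a single tf.train.CheckpointManager now serves both periodic checkpointing and recovery. Below is a condensed sketch of that wiring, not part of the commit: the CheckpointManager keyword arguments mirror the context shown above, while the orbit.Controller call and the extra TrainerConfig fields are assumptions about surrounding code this hunk does not show.

# Hedged sketch (not in this commit): one CheckpointManager shared between
# periodic checkpointing and the recovery train action.
import orbit
import tensorflow as tf

from official.core import actions


def build_controller(params, trainer, model_dir: str) -> orbit.Controller:
  checkpoint_manager = tf.train.CheckpointManager(
      trainer.checkpoint,
      directory=model_dir,
      max_to_keep=params.trainer.max_to_keep,
      step_counter=trainer.global_step,
      checkpoint_interval=params.trainer.checkpoint_interval,
      init_fn=trainer.initialize)
  return orbit.Controller(
      trainer=trainer,
      global_step=trainer.global_step,
      steps_per_loop=params.trainer.steps_per_loop,
      checkpoint_manager=checkpoint_manager,
      # The recovery ConditionalAction ends up in this list and restores from
      # the same manager that writes the periodic checkpoints.
      train_actions=actions.get_train_actions(
          params, trainer, model_dir, checkpoint_manager=checkpoint_manager),
      eval_actions=actions.get_eval_actions(params, trainer, model_dir))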
official/core/train_lib_test.py

@@ -19,6 +19,7 @@ import os
from absl import flags
from absl.testing import flagsaver
from absl.testing import parameterized
+import numpy as np
import tensorflow as tf

from tensorflow.python.distribute import combinations

@@ -30,6 +31,7 @@ from official.common import registry_imports
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
+from official.utils.testing import mock_task

FLAGS = flags.FLAGS
@@ -114,7 +116,93 @@ class TrainTest(tf.test.TestCase, parameterized.TestCase):
        params=params,
        model_dir=model_dir,
        run_post_eval=run_post_eval)

    print(logs)

  @combinations.generate(
      combinations.combine(
          distribution_strategy=[
              strategy_combinations.default_strategy,
              strategy_combinations.cloud_tpu_strategy,
              strategy_combinations.one_device_strategy_gpu,
          ],
          flag_mode=['train', 'train_and_eval'],
      ))
  def test_recovery_nan_error(self, distribution_strategy, flag_mode):
    model_dir = self.get_temp_dir()
    flags_dict = dict(
        experiment='mock',
        mode=flag_mode,
        model_dir=model_dir,
        params_override=json.dumps(self._test_config))
    with flagsaver.flagsaver(**flags_dict):
      params = train_utils.parse_configuration(flags.FLAGS)
      train_utils.serialize_config(params, model_dir)
      with distribution_strategy.scope():
        # task = task_factory.get_task(params.task, logging_dir=model_dir)
        task = mock_task.MockTask(params.task, logging_dir=model_dir)

        # Set the loss to NaN to trigger RuntimeError.
        def build_losses(labels, model_outputs, aux_losses=None):
          del labels, model_outputs
          return tf.constant([np.nan], tf.float32) + aux_losses

        task.build_losses = build_losses

      with self.assertRaises(RuntimeError):
        train_lib.run_experiment(
            distribution_strategy=distribution_strategy,
            task=task,
            mode=flag_mode,
            params=params,
            model_dir=model_dir)

  @combinations.generate(
      combinations.combine(
          distribution_strategy=[
              strategy_combinations.default_strategy,
              strategy_combinations.cloud_tpu_strategy,
              strategy_combinations.one_device_strategy_gpu,
          ],
          flag_mode=['train'],
      ))
  def test_recovery(self, distribution_strategy, flag_mode):
    loss_threshold = 1.0
    model_dir = self.get_temp_dir()
    flags_dict = dict(
        experiment='mock',
        mode=flag_mode,
        model_dir=model_dir,
        params_override=json.dumps(self._test_config))
    with flagsaver.flagsaver(**flags_dict):
      params = train_utils.parse_configuration(flags.FLAGS)
      params.trainer.loss_upper_bound = loss_threshold
      params.trainer.recovery_max_trials = 1
      train_utils.serialize_config(params, model_dir)

      with distribution_strategy.scope():
        task = task_factory.get_task(params.task, logging_dir=model_dir)

      # Saves a checkpoint for reference.
      model = task.build_model()
      checkpoint = tf.train.Checkpoint(model=model)
      checkpoint_manager = tf.train.CheckpointManager(
          checkpoint, self.get_temp_dir(), max_to_keep=2)
      checkpoint_manager.save()
      before_weights = model.get_weights()

      def build_losses(labels, model_outputs, aux_losses=None):
        del labels, model_outputs
        return tf.constant([loss_threshold], tf.float32) + aux_losses

      task.build_losses = build_losses

      model, _ = train_lib.run_experiment(
          distribution_strategy=distribution_strategy,
          task=task,
          mode=flag_mode,
          params=params,
          model_dir=model_dir)
      after_weights = model.get_weights()
      for left, right in zip(before_weights, after_weights):
        self.assertAllEqual(left, right)

  def test_parse_configuration(self):
    model_dir = self.get_temp_dir()
official/recommendation/ranking/train.py

@@ -43,11 +43,6 @@ class RankingTrainer(base_trainer.Trainer):
  def train_loop_end(self) -> Dict[str, float]:
    """See base class."""
    self.join()
-    # Checks if the model numeric status is stable and conducts the checkpoint
-    # recovery accordingly.
-    if self._recovery:
-      self._recovery.maybe_recover(self.train_loss.result().numpy(),
-                                   self.global_step.numpy())
    logs = {}
    for metric in self.train_metrics + [self.train_loss]:
      logs[metric.name] = metric.result()