ModelZoo / ResNet50_tensorflow / Commits / 08189186

Commit 08189186, authored Aug 05, 2021 by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 389019529
Parent: d088f0d5
Showing 7 changed files with 196 additions and 82 deletions (+196 -82):
official/core/actions.py                    +69  -3
official/core/actions_test.py               +35  -9
official/core/base_trainer.py                +1  -5
official/core/base_trainer_test.py           +0  -56
official/core/train_lib.py                   +2  -3
official/core/train_lib_test.py             +89  -1
official/recommendation/ranking/train.py     +0  -5
official/core/actions.py

@@ -16,6 +16,7 @@
 import os
 from typing import List
+from absl import logging
 import gin
 import orbit

@@ -119,6 +120,58 @@ class EMACheckpointing:
     self._optimizer.swap_weights()


+class RecoveryAction:
+  """Train action to recover from loss blowup.
+
+  Checks the loss value by the given threshold. If applicable, recover the
+  model by reading the checkpoint on disk.
+  """
+
+  def __init__(self, checkpoint_manager: tf.train.CheckpointManager):
+    self.checkpoint_manager = checkpoint_manager
+
+  def __call__(self, _):
+    """Recovers the training by triggering checkpoint restoration."""
+    # Loads the previous good checkpoint.
+    checkpoint_path = self.checkpoint_manager.restore_or_initialize()
+    logging.warning('Recovering the model from checkpoint: %s.',
+                    checkpoint_path)
+
+
+class RecoveryCondition:
+  """Recovery Condition."""
+
+  def __init__(self,
+               global_step: tf.Variable,
+               loss_upper_bound: float,
+               recovery_begin_steps: int = 0,
+               recovery_max_trials: int = 3):
+    self.recover_counter = 0
+    self.recovery_begin_steps = recovery_begin_steps
+    self.recovery_max_trials = recovery_max_trials
+    self.loss_upper_bound = loss_upper_bound
+    self.global_step = global_step
+
+  def __call__(self, outputs: orbit.runner.Output):
+    loss_value = outputs['training_loss']
+    if tf.math.is_nan(loss_value):
+      self.recover_counter += 1
+      if self.recover_counter > self.recovery_max_trials:
+        raise RuntimeError(
+            'The loss value is NaN after training loop and it happens %d times.'
+            % self.recover_counter)
+      return True
+    if (self.global_step >= self.recovery_begin_steps and
+        loss_value > self.loss_upper_bound):
+      self.recover_counter += 1
+      if self.recover_counter > self.recovery_max_trials:
+        raise RuntimeError(
+            f'The loss value is {loss_value}, which is larger than the bound '
+            f'{self.loss_upper_bound}, happens {self.recover_counter} times.')
+      return True
+    return False
+
+
 @gin.configurable
 def get_eval_actions(params: config_definitions.ExperimentConfig,

@@ -140,9 +193,10 @@ def get_eval_actions(
 @gin.configurable
-def get_train_actions(params: config_definitions.ExperimentConfig,
-                      trainer: base_trainer.Trainer,
-                      model_dir: str) -> List[orbit.Action]:
+def get_train_actions(
+    params: config_definitions.ExperimentConfig, trainer: base_trainer.Trainer,
+    model_dir: str,
+    checkpoint_manager: tf.train.CheckpointManager) -> List[orbit.Action]:
   """Gets train actions for TFM trainer."""
   train_actions = []
   # Adds pruning callback actions.

@@ -153,4 +207,16 @@ def get_train_actions(params: config_definitions.ExperimentConfig,
           model=trainer.model,
           optimizer=trainer.optimizer))

+  if params.trainer.recovery_max_trials >= 0:
+    recovery_condition = RecoveryCondition(
+        global_step=trainer.global_step,
+        loss_upper_bound=params.trainer.loss_upper_bound,
+        recovery_begin_steps=params.trainer.recovery_begin_steps,
+        recovery_max_trials=params.trainer.recovery_max_trials,
+    )
+    recover_action = orbit.actions.ConditionalAction(
+        condition=recovery_condition,
+        action=RecoveryAction(checkpoint_manager),
+    )
+    train_actions.append(recover_action)
   return train_actions
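Taken together, the new classes turn loss-blowup recovery into a composable orbit train action: RecoveryCondition flags a NaN or out-of-bound training loss, and RecoveryAction restores the last good checkpoint. A minimal standalone sketch of the wiring (illustrative only; the directory path and threshold values below are assumptions, not part of this commit):

    import orbit
    import tensorflow as tf

    from official.core import actions

    # Assumes an existing `model` and `optimizer`; the path and bounds are
    # placeholder values for illustration.
    global_step = orbit.utils.create_global_step()
    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint, directory='/tmp/recovery_ckpt', max_to_keep=3)
    checkpoint_manager.save()  # Saves a known-good state to fall back on.

    recovery = orbit.actions.ConditionalAction(
        condition=actions.RecoveryCondition(
            global_step=global_step,
            loss_upper_bound=2.0,    # Recover when the loss exceeds this bound.
            recovery_begin_steps=0,  # Check from the first step onward.
            recovery_max_trials=3),  # Raise RuntimeError past this many trials.
        action=actions.RecoveryAction(checkpoint_manager))

    # orbit calls each train action with the train-loop outputs, as in the
    # tests below:
    recovery({'training_loss': tf.constant(5.0)})  # Restores the checkpoint.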
official/core/actions_test.py

@@ -17,6 +17,8 @@
 import os
 from absl.testing import parameterized
+import numpy as np
+import orbit
 import tensorflow as tf
 from tensorflow.python.distribute import combinations

@@ -35,17 +37,14 @@ class TestModel(tf.Module):
     return self.value


-def all_strategy_combinations():
-  return combinations.combine(
-      distribution=[
-          strategy_combinations.cloud_tpu_strategy,
-          strategy_combinations.one_device_strategy_gpu,
-      ],)
-
-
 class ActionsTest(tf.test.TestCase, parameterized.TestCase):

-  @combinations.generate(all_strategy_combinations())
+  @combinations.generate(
+      combinations.combine(
+          distribution=[
+              strategy_combinations.cloud_tpu_strategy,
+              strategy_combinations.one_device_strategy_gpu,
+          ],))
   def test_ema_checkpointing(self, distribution):
     with distribution.scope():
       directory = self.create_tempdir()

@@ -76,6 +75,33 @@ class ActionsTest(tf.test.TestCase, parameterized.TestCase):
       # Checks model.value is 0 after swapping.
       self.assertEqual(model(), 0)

+  @combinations.generate(
+      combinations.combine(
+          distribution=[
+              strategy_combinations.default_strategy,
+              strategy_combinations.cloud_tpu_strategy,
+              strategy_combinations.one_device_strategy_gpu,
+          ],))
+  def test_recovery_condition(self, distribution):
+    with distribution.scope():
+      global_step = orbit.utils.create_global_step()
+      recover_condition = actions.RecoveryCondition(
+          global_step, loss_upper_bound=0.5, recovery_max_trials=2)
+      outputs = {'training_loss': 0.6}
+      self.assertTrue(recover_condition(outputs))
+      self.assertTrue(recover_condition(outputs))
+      with self.assertRaises(RuntimeError):
+        recover_condition(outputs)
+
+      global_step = orbit.utils.create_global_step()
+      recover_condition = actions.RecoveryCondition(
+          global_step, loss_upper_bound=0.5, recovery_max_trials=2)
+      outputs = {'training_loss': tf.constant([np.nan], tf.float32)}
+      self.assertTrue(recover_condition(outputs))
+      self.assertTrue(recover_condition(outputs))
+      with self.assertRaises(RuntimeError):
+        recover_condition(outputs)
+

 if __name__ == '__main__':
   tf.test.main()
official/core/base_trainer.py

@@ -370,6 +370,7 @@ class Trainer(_AsyncTrainer):
     """Accesses the training checkpoint."""
     return self._checkpoint

+  # TODO(yejiayu): Remove this once all deps are fixed.
   def add_recovery(self, params: TrainerConfig,
                    checkpoint_manager: tf.train.CheckpointManager):
     if params.recovery_max_trials >= 0:

@@ -382,11 +383,6 @@ class Trainer(_AsyncTrainer):
   def train_loop_end(self):
     """See base class."""
     self.join()
-    # Checks if the model numeric status is stable and conducts the checkpoint
-    # recovery accordingly.
-    if self._recovery:
-      self._recovery.maybe_recover(self.train_loss.result().numpy(),
-                                   self.global_step.numpy())
     logs = {}
     for metric in self.train_metrics + [self.train_loss]:
       logs[metric.name] = metric.result()
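The recovery check removed from train_loop_end above now runs outside the trainer: the orbit controller applies the train actions returned by get_train_actions after each training loop. A rough sketch of the resulting control flow (illustrative, not the orbit source):

    # Illustrative loop, assuming `trainer` and `train_actions` from the
    # diffs above.
    outputs = trainer.train_loop_end()  # Now returns metric logs only.
    for action in train_actions:
      # For the recovery action, RecoveryCondition(outputs) decides whether
      # RecoveryAction restores the last checkpoint; once the trigger count
      # exceeds recovery_max_trials it raises RuntimeError instead.
      action(outputs)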
official/core/base_trainer_test.py

@@ -19,7 +19,6 @@ import os
 import sys
 from absl.testing import parameterized
-import numpy as np
 import orbit
 import portpicker
 import tensorflow as tf

@@ -337,61 +336,6 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
     self.assertTrue(
         tf.io.gfile.exists(os.path.join(model_dir, 'best_ckpt', 'info.json')))

-  def test_recovery(self):
-    config = cfg.ExperimentConfig(
-        trainer=cfg.TrainerConfig(
-            loss_upper_bound=0.5,
-            recovery_max_trials=2,
-            optimizer_config=cfg.OptimizationConfig({
-                'optimizer': {'type': 'sgd'},
-                'learning_rate': {'type': 'constant'}
-            })))
-    model_dir = self.get_temp_dir()
-    trainer = self.create_test_trainer(config, model_dir=model_dir)
-    checkpoint_manager = tf.train.CheckpointManager(
-        trainer.checkpoint, self.get_temp_dir(), max_to_keep=2)
-    checkpoint_manager.save()
-    trainer.add_recovery(config.trainer, checkpoint_manager=checkpoint_manager)
-    before_weights = trainer.model.get_weights()
-    _ = trainer.train(tf.convert_to_tensor(1, dtype=tf.int32))
-    # The training loss is 1.0 and upper_bound is 0.5, so the recover happens.
-    after_weights = trainer.model.get_weights()
-    for left, right in zip(before_weights, after_weights):
-      self.assertAllEqual(left, right)
-
-    # Let's the loss be NaN and max_trials = 0 to see RuntimeError.
-    config = cfg.ExperimentConfig(
-        trainer=cfg.TrainerConfig(
-            recovery_max_trials=0,
-            optimizer_config=cfg.OptimizationConfig({
-                'optimizer': {'type': 'sgd'},
-                'learning_rate': {'type': 'constant'}
-            })))
-    task = mock_task.MockTask(config.task, logging_dir=model_dir)
-
-    def build_losses(labels, model_outputs, aux_losses=None):
-      del labels, model_outputs
-      return tf.constant([np.nan], tf.float32) + aux_losses
-
-    task.build_losses = build_losses
-
-    trainer = trainer_lib.Trainer(
-        config,
-        task,
-        model=task.build_model(),
-        optimizer=task.create_optimizer(config.trainer.optimizer_config,
-                                        config.runtime))
-    trainer.add_recovery(config.trainer, checkpoint_manager=checkpoint_manager)
-    with self.assertRaises(RuntimeError):
-      _ = trainer.train(tf.convert_to_tensor(2, dtype=tf.int32))
-
   def test_model_with_compiled_loss(self):
     task = mock_task.MockTask()
     model = task.build_model()
official/core/train_lib.py

@@ -87,8 +87,6 @@ def run_experiment(
         step_counter=trainer.global_step,
         checkpoint_interval=params.trainer.checkpoint_interval,
         init_fn=trainer.initialize)
-    # Adds recovery handling.
-    trainer.add_recovery(params.trainer, checkpoint_manager=checkpoint_manager)
   else:
     checkpoint_manager = None

@@ -105,7 +103,8 @@ def run_experiment(
       (save_summary) else None,
       summary_interval=params.trainer.summary_interval if
       (save_summary) else None,
-      train_actions=actions.get_train_actions(params, trainer, model_dir),
+      train_actions=actions.get_train_actions(
+          params, trainer, model_dir, checkpoint_manager=checkpoint_manager),
       eval_actions=actions.get_eval_actions(params, trainer, model_dir))

   logging.info('Starts to execute mode: %s', mode)
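The net effect in run_experiment: the CheckpointManager built for periodic checkpointing is now also handed to get_train_actions, replacing the old trainer.add_recovery call. A hedged caller-side sketch (object names taken from the diff; the max_to_keep value is an assumption):

    # Reusing the periodic-checkpointing manager for recovery.
    checkpoint_manager = tf.train.CheckpointManager(
        trainer.checkpoint, directory=model_dir, max_to_keep=3)
    train_actions = actions.get_train_actions(
        params, trainer, model_dir, checkpoint_manager=checkpoint_manager)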
official/core/train_lib_test.py

@@ -19,6 +19,7 @@ import os
 from absl import flags
 from absl.testing import flagsaver
 from absl.testing import parameterized
+import numpy as np
 import tensorflow as tf
 from tensorflow.python.distribute import combinations

@@ -30,6 +31,7 @@ from official.common import registry_imports
 from official.core import task_factory
 from official.core import train_lib
 from official.core import train_utils
+from official.utils.testing import mock_task

 FLAGS = flags.FLAGS

@@ -114,7 +116,93 @@ class TrainTest(tf.test.TestCase, parameterized.TestCase):
         params=params,
         model_dir=model_dir,
         run_post_eval=run_post_eval)
     print(logs)

+  @combinations.generate(
+      combinations.combine(
+          distribution_strategy=[
+              strategy_combinations.default_strategy,
+              strategy_combinations.cloud_tpu_strategy,
+              strategy_combinations.one_device_strategy_gpu,
+          ],
+          flag_mode=['train', 'train_and_eval'],
+      ))
+  def test_recovery_nan_error(self, distribution_strategy, flag_mode):
+    model_dir = self.get_temp_dir()
+    flags_dict = dict(
+        experiment='mock',
+        mode=flag_mode,
+        model_dir=model_dir,
+        params_override=json.dumps(self._test_config))
+    with flagsaver.flagsaver(**flags_dict):
+      params = train_utils.parse_configuration(flags.FLAGS)
+      train_utils.serialize_config(params, model_dir)
+      with distribution_strategy.scope():
+        # task = task_factory.get_task(params.task, logging_dir=model_dir)
+        task = mock_task.MockTask(params.task, logging_dir=model_dir)
+
+        # Set the loss to NaN to trigger RunTimeError.
+        def build_losses(labels, model_outputs, aux_losses=None):
+          del labels, model_outputs
+          return tf.constant([np.nan], tf.float32) + aux_losses
+
+        task.build_losses = build_losses
+
+      with self.assertRaises(RuntimeError):
+        train_lib.run_experiment(
+            distribution_strategy=distribution_strategy,
+            task=task,
+            mode=flag_mode,
+            params=params,
+            model_dir=model_dir)
+
+  @combinations.generate(
+      combinations.combine(
+          distribution_strategy=[
+              strategy_combinations.default_strategy,
+              strategy_combinations.cloud_tpu_strategy,
+              strategy_combinations.one_device_strategy_gpu,
+          ],
+          flag_mode=['train'],
+      ))
+  def test_recovery(self, distribution_strategy, flag_mode):
+    loss_threshold = 1.0
+    model_dir = self.get_temp_dir()
+    flags_dict = dict(
+        experiment='mock',
+        mode=flag_mode,
+        model_dir=model_dir,
+        params_override=json.dumps(self._test_config))
+    with flagsaver.flagsaver(**flags_dict):
+      params = train_utils.parse_configuration(flags.FLAGS)
+      params.trainer.loss_upper_bound = loss_threshold
+      params.trainer.recovery_max_trials = 1
+      train_utils.serialize_config(params, model_dir)
+
+      with distribution_strategy.scope():
+        task = task_factory.get_task(params.task, logging_dir=model_dir)
+        # Saves a checkpoint for reference.
+        model = task.build_model()
+        checkpoint = tf.train.Checkpoint(model=model)
+        checkpoint_manager = tf.train.CheckpointManager(
+            checkpoint, self.get_temp_dir(), max_to_keep=2)
+        checkpoint_manager.save()
+        before_weights = model.get_weights()
+
+        def build_losses(labels, model_outputs, aux_losses=None):
+          del labels, model_outputs
+          return tf.constant([loss_threshold], tf.float32) + aux_losses
+
+        task.build_losses = build_losses
+
+      model, _ = train_lib.run_experiment(
+          distribution_strategy=distribution_strategy,
+          task=task,
+          mode=flag_mode,
+          params=params,
+          model_dir=model_dir)
+      after_weights = model.get_weights()
+      for left, right in zip(before_weights, after_weights):
+        self.assertAllEqual(left, right)
+
   def test_parse_configuration(self):
     model_dir = self.get_temp_dir()
official/recommendation/ranking/train.py

@@ -43,11 +43,6 @@ class RankingTrainer(base_trainer.Trainer):
   def train_loop_end(self) -> Dict[str, float]:
     """See base class."""
     self.join()
-    # Checks if the model numeric status is stable and conducts the checkpoint
-    # recovery accordingly.
-    if self._recovery:
-      self._recovery.maybe_recover(self.train_loss.result().numpy(),
-                                   self.global_step.numpy())
     logs = {}
     for metric in self.train_metrics + [self.train_loss]:
       logs[metric.name] = metric.result()