ModelZoo / ResNet50_tensorflow · Commit 6dbdb08c

Authored Jul 01, 2022 by Yeqing Li; committed by A. Unique TensorFlower on Jul 01, 2022.

Refactors the run_experiment function for better reusability.

PiperOrigin-RevId: 458550388
Parent: f1add1bc
Showing 2 changed files with 292 additions and 91 deletions.
official/core/train_lib.py: +233, -87
official/core/train_lib_test.py: +59, -4
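In brief, the change moves the body of run_experiment into a new OrbitExperimentRunner class, leaving run_experiment as a thin wrapper around it. A minimal sketch of the two equivalent call forms, assuming strategy, task, params, and model_dir have already been built as usual for a Model Garden experiment (this sketch is illustrative, not part of the commit):

    from official.core import train_lib

    # Existing entry point: unchanged signature, now delegates to the class.
    model, eval_logs = train_lib.run_experiment(
        distribution_strategy=strategy,
        task=task,
        mode='train_and_eval',
        params=params,
        model_dir=model_dir,
        run_post_eval=True)

    # New reusable entry point introduced by this commit; equivalent to the
    # call above, and subclassable to swap out individual components.
    runner = train_lib.OrbitExperimentRunner(
        distribution_strategy=strategy,
        task=task,
        mode='train_and_eval',
        params=params,
        model_dir=model_dir,
        run_post_eval=True)
    model, eval_logs = runner.run()  # eval_logs holds post-eval metrics.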
official/core/train_lib.py
@@ -32,6 +32,226 @@ from official.core import train_utils
maybe_create_best_ckpt_exporter = train_utils.maybe_create_best_ckpt_exporter


class OrbitExperimentRunner:
  """Runs experiment with Orbit training loop.

  The default experiment runner for model garden experiments. Users can
  customize the experiment pipeline by subclassing this class and replacing
  components or functions.

  For example, an experiment runner with a customized checkpoint manager:

  ```python
  class MyExpRunnerWithExporter(OrbitExperimentRunner):
    def _maybe_build_checkpoint_manager(self):
      return MyCheckpointManager(*args)

  # In user code
  MyExpRunnerWithExporter(**needed_kwargs).run()
  ```

  Similar overrides can be applied to the other components.
  """

  def __init__(
      self,
      distribution_strategy: tf.distribute.Strategy,
      task: base_task.Task,
      mode: str,
      params: config_definitions.ExperimentConfig,
      model_dir: str,
      run_post_eval: bool = False,
      save_summary: bool = True,
      train_actions: Optional[List[orbit.Action]] = None,
      eval_actions: Optional[List[orbit.Action]] = None,
      trainer: Optional[base_trainer.Trainer] = None,
      controller_cls=orbit.Controller):
    """Constructor.

    Args:
      distribution_strategy: A distribution strategy.
      task: A Task instance.
      mode: A 'str', specifying the mode. Can be 'train', 'eval',
        'train_and_eval' or 'continuous_eval'.
      params: ExperimentConfig instance.
      model_dir: A 'str', a path to store model checkpoints and summaries.
      run_post_eval: Whether to run post eval once after training; metrics
        logs are returned.
      save_summary: Whether to save train and validation summary.
      train_actions: Optional list of Orbit train actions.
      eval_actions: Optional list of Orbit eval actions.
      trainer: the base_trainer.Trainer instance. It should be created within
        the strategy.scope().
      controller_cls: The controller class to manage the train and eval
        process. Must be an orbit.Controller subclass.
    """
    self.strategy = distribution_strategy or tf.distribute.get_strategy()
    self._params = params
    self._model_dir = model_dir
    self._mode = mode
    self._run_post_eval = run_post_eval

    self._trainer = trainer or self._build_trainer(
        task,
        train='train' in mode,
        evaluate=('eval' in mode) or run_post_eval)
    assert self.trainer is not None
    self._checkpoint_manager = self._maybe_build_checkpoint_manager()
    self._controller = self._build_controller(
        trainer=self.trainer if 'train' in mode else None,
        evaluator=self.trainer,
        save_summary=save_summary,
        train_actions=train_actions,
        eval_actions=eval_actions,
        controller_cls=controller_cls)

  @property
  def params(self) -> config_definitions.ExperimentConfig:
    return self._params

  @property
  def model_dir(self) -> str:
    return self._model_dir

  @property
  def trainer(self) -> base_trainer.Trainer:
    return self._trainer

  @property
  def checkpoint_manager(self) -> tf.train.CheckpointManager:
    return self._checkpoint_manager

  @property
  def controller(self) -> orbit.Controller:
    return self._controller

  def _build_trainer(self, task: base_task.Task, train: bool,
                     evaluate: bool) -> base_trainer.Trainer:
    """Create trainer."""
    with self.strategy.scope():
      trainer = train_utils.create_trainer(
          self.params,
          task,
          train=train,
          evaluate=evaluate,
          checkpoint_exporter=self._build_best_checkpoint_exporter())
    return trainer

  def _build_best_checkpoint_exporter(self):
    return maybe_create_best_ckpt_exporter(self.params, self.model_dir)

  def _maybe_build_checkpoint_manager(
      self) -> Optional[tf.train.CheckpointManager]:
    """Maybe create a CheckpointManager."""
    assert self.trainer is not None
    if self.trainer.checkpoint:
      if self.model_dir is None:
        raise ValueError('model_dir must be specified, but got None')
      checkpoint_manager = tf.train.CheckpointManager(
          self.trainer.checkpoint,
          directory=self.model_dir,
          max_to_keep=self.params.trainer.max_to_keep,
          step_counter=self.trainer.global_step,
          checkpoint_interval=self.params.trainer.checkpoint_interval,
          init_fn=self.trainer.initialize)
    else:
      checkpoint_manager = None
    return checkpoint_manager

  def _build_controller(self,
                        trainer,
                        evaluator,
                        save_summary: bool = True,
                        train_actions: Optional[List[orbit.Action]] = None,
                        eval_actions: Optional[List[orbit.Action]] = None,
                        controller_cls=orbit.Controller) -> orbit.Controller:
    """Builds an Orbit controller."""
    train_actions = [] if not train_actions else train_actions
    if trainer:
      train_actions += actions.get_train_actions(
          self.params,
          trainer,
          self.model_dir,
          checkpoint_manager=self.checkpoint_manager)

    eval_actions = [] if not eval_actions else eval_actions
    if evaluator:
      eval_actions += actions.get_eval_actions(self.params, evaluator,
                                               self.model_dir)

    controller = controller_cls(
        strategy=self.strategy,
        trainer=trainer,
        evaluator=evaluator,
        global_step=self.trainer.global_step,
        steps_per_loop=self.params.trainer.steps_per_loop,
        checkpoint_manager=self.checkpoint_manager,
        summary_dir=os.path.join(self.model_dir, 'train')
        if (save_summary) else None,
        eval_summary_dir=os.path.join(
            self.model_dir, self.params.trainer.validation_summary_subdir)
        if (save_summary) else None,
        summary_interval=self.params.trainer.summary_interval
        if (save_summary) else None,
        train_actions=train_actions,
        eval_actions=eval_actions)
    return controller

  def run(self) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
    """Run experiments by mode.

    Returns:
      A 2-tuple of (model, eval_logs).
        model: `tf.keras.Model` instance.
        eval_logs: returns eval metrics logs when run_post_eval is set to
          True; otherwise, returns {}.
    """
    mode = self._mode
    params = self.params
    logging.info('Starts to execute mode: %s', mode)
    with self.strategy.scope():
      if mode == 'train' or mode == 'train_and_post_eval':
        self.controller.train(steps=params.trainer.train_steps)
      elif mode == 'train_and_eval':
        self.controller.train_and_evaluate(
            train_steps=params.trainer.train_steps,
            eval_steps=params.trainer.validation_steps,
            eval_interval=params.trainer.validation_interval)
      elif mode == 'eval':
        self.controller.evaluate(steps=params.trainer.validation_steps)
      elif mode == 'continuous_eval':

        def timeout_fn():
          if self.trainer.global_step.numpy() >= params.trainer.train_steps:
            return True
          return False

        self.controller.evaluate_continuously(
            steps=params.trainer.validation_steps,
            timeout=params.trainer.continuous_eval_timeout,
            timeout_fn=timeout_fn)
      else:
        raise NotImplementedError('The mode is not implemented: %s' % mode)

    num_params = train_utils.try_count_params(self.trainer.model)
    if num_params is not None:
      logging.info('Number of trainable params in model: %f Millions.',
                   num_params / 10.**6)

    flops = train_utils.try_count_flops(self.trainer.model)
    if flops is not None:
      logging.info('FLOPs (multi-adds) in model: %f Billions.',
                   flops / 10.**9 / 2)

    if self._run_post_eval or mode == 'train_and_post_eval':
      with self.strategy.scope():
        return self.trainer.model, self.controller.evaluate(
            steps=params.trainer.validation_steps)
    else:
      return self.trainer.model, {}


def run_experiment(
    distribution_strategy: tf.distribute.Strategy,
    task: base_task.Task,
    ...
@@ -70,91 +290,17 @@ def run_experiment(
       eval_logs: returns eval metrics logs when run_post_eval is set to True,
         otherwise, returns {}.
   """
-  with distribution_strategy.scope():
-    if not trainer:
-      trainer = train_utils.create_trainer(
-          params,
-          task,
-          train='train' in mode,
-          evaluate=('eval' in mode) or run_post_eval,
-          checkpoint_exporter=maybe_create_best_ckpt_exporter(
-              params, model_dir))
-
-  if trainer.checkpoint:
-    if model_dir is None:
-      raise ValueError('model_dir must be specified, but got None')
-    checkpoint_manager = tf.train.CheckpointManager(
-        trainer.checkpoint,
-        directory=model_dir,
-        max_to_keep=params.trainer.max_to_keep,
-        step_counter=trainer.global_step,
-        checkpoint_interval=params.trainer.checkpoint_interval,
-        init_fn=trainer.initialize)
-  else:
-    checkpoint_manager = None
-
-  train_actions = [] if not train_actions else train_actions
-  train_actions += actions.get_train_actions(
-      params, trainer, model_dir, checkpoint_manager=checkpoint_manager)
-
-  eval_actions = [] if not eval_actions else eval_actions
-  eval_actions += actions.get_eval_actions(params, trainer, model_dir)
-
-  controller = controller_cls(
-      strategy=distribution_strategy,
-      trainer=trainer if 'train' in mode else None,
-      evaluator=trainer,
-      global_step=trainer.global_step,
-      steps_per_loop=params.trainer.steps_per_loop,
-      checkpoint_manager=checkpoint_manager,
-      summary_dir=os.path.join(model_dir, 'train')
-      if (save_summary) else None,
-      eval_summary_dir=os.path.join(
-          model_dir, params.trainer.validation_summary_subdir)
-      if (save_summary) else None,
-      summary_interval=params.trainer.summary_interval
-      if (save_summary) else None,
-      train_actions=train_actions,
-      eval_actions=eval_actions)
-
-  logging.info('Starts to execute mode: %s', mode)
-  with distribution_strategy.scope():
-    if mode == 'train' or mode == 'train_and_post_eval':
-      controller.train(steps=params.trainer.train_steps)
-    elif mode == 'train_and_eval':
-      controller.train_and_evaluate(
-          train_steps=params.trainer.train_steps,
-          eval_steps=params.trainer.validation_steps,
-          eval_interval=params.trainer.validation_interval)
-    elif mode == 'eval':
-      controller.evaluate(steps=params.trainer.validation_steps)
-    elif mode == 'continuous_eval':
-
-      def timeout_fn():
-        if trainer.global_step.numpy() >= params.trainer.train_steps:
-          return True
-        return False
-
-      controller.evaluate_continuously(
-          steps=params.trainer.validation_steps,
-          timeout=params.trainer.continuous_eval_timeout,
-          timeout_fn=timeout_fn)
-    else:
-      raise NotImplementedError('The mode is not implemented: %s' % mode)
-
-  num_params = train_utils.try_count_params(trainer.model)
-  if num_params is not None:
-    logging.info('Number of trainable params in model: %f Millions.',
-                 num_params / 10.**6)
-
-  flops = train_utils.try_count_flops(trainer.model)
-  if flops is not None:
-    logging.info('FLOPs (multi-adds) in model: %f Billions.',
-                 flops / 10.**9 / 2)
-
-  if run_post_eval or mode == 'train_and_post_eval':
-    with distribution_strategy.scope():
-      return trainer.model, controller.evaluate(
-          steps=params.trainer.validation_steps)
-  else:
-    return trainer.model, {}
+  runner = OrbitExperimentRunner(
+      distribution_strategy=distribution_strategy,
+      task=task,
+      mode=mode,
+      params=params,
+      model_dir=model_dir,
+      run_post_eval=run_post_eval,
+      save_summary=save_summary,
+      train_actions=train_actions,
+      eval_actions=eval_actions,
+      trainer=trainer,
+      controller_cls=controller_cls,
+  )
+  return runner.run()
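The class docstring in the first hunk names _maybe_build_checkpoint_manager as an override point. As a concrete illustration, a hypothetical subclass that keeps every checkpoint rather than the configured max_to_keep could look like the sketch below; the subclass name and its retention policy are illustrative, not part of this commit:

    import tensorflow as tf

    from official.core import train_lib


    class KeepAllCkptExpRunner(train_lib.OrbitExperimentRunner):
      """Hypothetical runner that never deletes old checkpoints."""

      def _maybe_build_checkpoint_manager(self):
        if not self.trainer.checkpoint:
          return None
        # Same wiring as the base class, except max_to_keep=None, which
        # tells tf.train.CheckpointManager to retain every checkpoint.
        return tf.train.CheckpointManager(
            self.trainer.checkpoint,
            directory=self.model_dir,
            max_to_keep=None,
            step_counter=self.trainer.global_step,
            checkpoint_interval=self.params.trainer.checkpoint_interval,
            init_fn=self.trainer.initialize)

Because __init__ resolves _maybe_build_checkpoint_manager through self, the override takes effect without duplicating the rest of the train/eval pipeline, which is the reusability the commit message refers to.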
official/core/train_lib_test.py
@@ -117,6 +117,61 @@ class TrainTest(tf.test.TestCase, parameterized.TestCase):
        model_dir=model_dir,
        run_post_eval=run_post_eval)

  @combinations.generate(
      combinations.combine(
          distribution_strategy=[
              strategy_combinations.default_strategy,
              strategy_combinations.cloud_tpu_strategy,
              strategy_combinations.one_device_strategy_gpu,
          ],
          flag_mode=['train', 'eval', 'train_and_eval'],
          run_post_eval=[True, False]))
  def test_end_to_end_class(self, distribution_strategy, flag_mode,
                            run_post_eval):
    model_dir = self.get_temp_dir()
    flags_dict = dict(
        experiment='mock',
        mode=flag_mode,
        model_dir=model_dir,
        params_override=json.dumps(self._test_config))
    with flagsaver.flagsaver(**flags_dict):
      params = train_utils.parse_configuration(flags.FLAGS)
      train_utils.serialize_config(params, model_dir)
      with distribution_strategy.scope():
        task = task_factory.get_task(params.task, logging_dir=model_dir)

      _, logs = train_lib.OrbitExperimentRunner(
          distribution_strategy=distribution_strategy,
          task=task,
          mode=flag_mode,
          params=params,
          model_dir=model_dir,
          run_post_eval=run_post_eval).run()

    if 'eval' in flag_mode:
      self.assertTrue(
          tf.io.gfile.exists(
              os.path.join(model_dir,
                           params.trainer.validation_summary_subdir)))
    if run_post_eval:
      self.assertNotEmpty(logs)
    else:
      self.assertEmpty(logs)

    self.assertNotEmpty(
        tf.io.gfile.glob(os.path.join(model_dir, 'params.yaml')))
    if flag_mode == 'eval':
      return
    self.assertNotEmpty(
        tf.io.gfile.glob(os.path.join(model_dir, 'checkpoint')))
    # Tests continuous evaluation.
    _, logs = train_lib.OrbitExperimentRunner(
        distribution_strategy=distribution_strategy,
        task=task,
        mode='continuous_eval',
        params=params,
        model_dir=model_dir,
        run_post_eval=run_post_eval).run()

  @combinations.generate(
      combinations.combine(
          distribution_strategy=[
          ...
@@ -148,12 +203,12 @@ class TrainTest(tf.test.TestCase, parameterized.TestCase):
     task.build_losses = build_losses

     with self.assertRaises(RuntimeError):
-      train_lib.run_experiment(
+      train_lib.OrbitExperimentRunner(
           distribution_strategy=distribution_strategy,
           task=task,
           mode=flag_mode,
           params=params,
-          model_dir=model_dir)
+          model_dir=model_dir).run()

   @combinations.generate(
       combinations.combine(
           ...
@@ -194,12 +249,12 @@ class TrainTest(tf.test.TestCase, parameterized.TestCase):
     task.build_losses = build_losses

-    model, _ = train_lib.run_experiment(
+    model, _ = train_lib.OrbitExperimentRunner(
         distribution_strategy=distribution_strategy,
         task=task,
         mode=flag_mode,
         params=params,
-        model_dir=model_dir)
+        model_dir=model_dir).run()
     after_weights = model.get_weights()
     for left, right in zip(before_weights, after_weights):
       self.assertAllEqual(left, right)