ModelZoo / ResNet50_tensorflow · Commits

Commit 856622d3
Authored May 17, 2021 by Hongkun Yu; committed by A. Unique TensorFlower, May 17, 2021

Internal change

PiperOrigin-RevId: 374244811
Parent: fb9f35c8

Showing 2 changed files with 72 additions and 28 deletions:

  official/core/base_trainer.py       +40  -20
  official/core/base_trainer_test.py  +32   -8
official/core/base_trainer.py
@@ -19,6 +19,7 @@ The base trainer implements the Orbit `StandardTrainable` and
 interchangable and independent on model architectures and tasks.
 """
 import functools
+from typing import Union, Optional
 from absl import logging
 import gin
 import orbit
@@ -28,7 +29,6 @@ from official.core import base_task
 from official.core import config_definitions
 from official.modeling import optimization

 ExperimentConfig = config_definitions.ExperimentConfig
 TrainerConfig = config_definitions.TrainerConfig
@@ -143,6 +143,7 @@ class _AsyncTrainer(orbit.StandardTrainer, orbit.StandardEvaluator):
         `tf.distribute.InputContext` instance.
       *args: Any positional arguments to pass through to `dataset_or_fn`.
       **kwargs: Any keyword arguments to pass through to `dataset_or_fn`.
+
     Returns:
       A distributed Dataset.
     """
@@ -173,13 +174,18 @@ class Trainer(_AsyncTrainer):
   """Implements the common trainer shared for TensorFlow models."""

   # pylint: disable=super-init-not-called
   def __init__(self,
                config: ExperimentConfig,
                task: base_task.Task,
                model: tf.keras.Model,
                optimizer: tf.optimizers.Optimizer,
                train: bool = True,
                evaluate: bool = True,
+               train_dataset: Optional[Union[tf.data.Dataset,
+                   tf.distribute.DistributedDataset]] = None,
+               validation_dataset: Optional[Union[tf.data.Dataset,
+                   tf.distribute.DistributedDataset]] = None,
                checkpoint_exporter=None):
     """Initialize common trainer for TensorFlow models.
@@ -192,13 +198,22 @@ class Trainer(_AsyncTrainer):
         default to True.
       evaluate: bool, whether or not this trainer will be used for evaluation.
         default to True.
+      train_dataset: a dataset object created for training. With tf.distribute,
+        it needs to be a `DistributedDataset`.
+      validation_dataset: a dataset object created for evaluation. With
+        tf.distribute, it needs to be a `DistributedDataset`. The evaluator will
+        create a dataset iterator for each eval round, so the dataset does not
+        need to repeat.
       checkpoint_exporter: an object that has the `maybe_export_checkpoint`
         interface.
     """
     # Gets the current distribution strategy. If not inside any strategy scope,
     # it gets a single-replica no-op strategy.
     self._strategy = tf.distribute.get_strategy()
-    self._validate_params(config)
+    self._validate_params(
+        config,
+        check_train_data=train_dataset is None,
+        check_validation_data=validation_dataset is None)
     self._config = config
     self._task = task
     self._model = model
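
Taken together, the hunks above let callers inject pre-built (and possibly already-distributed) datasets instead of having the trainer derive them from `config.task.train_data` / `config.task.validation_data`. A hedged usage sketch, assuming `config`, `task`, and `trainer_lib` are already set up the usual Model Garden way (it mirrors the new test further below):

import orbit
import tensorflow as tf

strategy = tf.distribute.get_strategy()
with strategy.scope():
  train_ds = orbit.utils.make_distributed_dataset(
      strategy, task.build_inputs, config.task.train_data)
  trainer = trainer_lib.Trainer(
      config,
      task,
      model=task.build_model(),
      optimizer=task.create_optimizer(config.trainer.optimizer_config,
                                      config.runtime),
      train_dataset=train_ds)  # bypasses the config-driven dataset build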
@@ -239,7 +254,7 @@ class Trainer(_AsyncTrainer):
     self.init_async()

     if train:
-      train_dataset = self.distribute_dataset(
+      train_dataset = train_dataset or self.distribute_dataset(
           self.task.build_inputs, self.config.task.train_data)
       orbit.StandardTrainer.__init__(
           self,
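
The fallback above is the whole mechanism on the training side: an injected dataset wins, otherwise the trainer builds one from the config as before. A tiny standalone illustration of the idiom (names hypothetical; it is safe here because `tf.data` datasets are always truthy):

def resolve_dataset(injected, build_default):
  # Fall through to the factory only when nothing was injected.
  return injected or build_default()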
@@ -250,16 +265,19 @@ class Trainer(_AsyncTrainer):
             use_tpu_summary_optimization=config.trainer.allow_tpu_summary))

     if evaluate:
-      eval_dataset = self.distribute_dataset(
+      validation_dataset = validation_dataset or self.distribute_dataset(
           self.task.build_inputs, self.config.task.validation_data)
       orbit.StandardEvaluator.__init__(
           self,
-          eval_dataset,
+          validation_dataset,
           options=orbit.StandardEvaluatorOptions(
               use_tf_function=config.trainer.eval_tf_function,
               use_tf_while_loop=config.trainer.eval_tf_while_loop))

-  def _validate_params(self, config):
+  def _validate_params(self,
+                       config,
+                       check_train_data=True,
+                       check_validation_data=True):
     r"""Validates if the configuration object passed to the Trainer.

     The experiment configuration should be structured as:
@@ -270,6 +288,8 @@ class Trainer(_AsyncTrainer):
     Args:
       config: a namedtuple, dataclass, ConfigDict, etc.
+      check_train_data: whether to check task.train_data field.
+      check_validation_data: whether to check task.validation_data field.
     """
     if not hasattr(config, "trainer"):
       raise AttributeError("The trainer requires the configuration contains an"
@@ -279,11 +299,11 @@ class Trainer(_AsyncTrainer):
       raise AttributeError("The trainer requires the configuration contains an"
                            " attribute `task`.")

-    if not hasattr(config.task, "train_data"):
+    if check_train_data and not hasattr(config.task, "train_data"):
       raise AttributeError("The trainer requires the configuration contains an"
                            " attribute `task.train_data`.")

-    if not hasattr(config.task, "validation_data"):
+    if check_validation_data and not hasattr(config.task, "validation_data"):
       raise AttributeError("The trainer requires the configuration contains an"
                            " attribute `task.validation_data`.")
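
The relaxed validation follows one pattern: each required-attribute check is skipped exactly when the caller already supplied the corresponding dataset. A self-contained sketch of the gating outside the Trainer (names hypothetical):

def validate(config, check_train_data=True, check_validation_data=True):
  if check_train_data and not hasattr(config.task, "train_data"):
    raise AttributeError("config is missing `task.train_data`.")
  if check_validation_data and not hasattr(config.task, "validation_data"):
    raise AttributeError("config is missing `task.validation_data`.")

# The trainer ties each check to whether a dataset was injected:
# validate(config, check_train_data=train_dataset is None,
#          check_validation_data=validation_dataset is None)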
@@ -406,8 +426,8 @@ class Trainer(_AsyncTrainer):
     for metric in self.validation_metrics + [self.validation_loss]:
       metric.reset_states()
     # Swaps weights to test on weights moving average.
     if self.optimizer and isinstance(self.optimizer,
                                      optimization.ExponentialMovingAverage):
       self.optimizer.swap_weights()

   def eval_step(self, iterator):
@@ -451,8 +471,8 @@ class Trainer(_AsyncTrainer):
     # Swaps back weights after testing when EMA is used.
     # This happens after best checkpoint export so that average weights used for
     # eval are exported instead of regular weights.
     if self.optimizer and isinstance(self.optimizer,
                                      optimization.ExponentialMovingAverage):
       self.optimizer.swap_weights()
     return logs
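For orientation, the two EMA hunks above bracket evaluation with a weight swap: metrics are reset and the averaged weights swapped in before eval, and the swap is undone only after best-checkpoint export so the averages are what get exported. A minimal sketch of the pattern; `ExponentialMovingAverage` is the wrapper from `official.modeling.optimization`, while `maybe_export_best_checkpoint` is a hypothetical stand-in for the exporter hook:

def eval_begin(self):
  for metric in self.validation_metrics + [self.validation_loss]:
    metric.reset_states()
  # Evaluate against the averaged weights, not the raw training weights.
  if self.optimizer and isinstance(self.optimizer,
                                   optimization.ExponentialMovingAverage):
    self.optimizer.swap_weights()

def eval_end(self, logs):
  maybe_export_best_checkpoint(logs)  # averaged weights are still live here
  # Swap back so training resumes on the raw weights.
  if self.optimizer and isinstance(self.optimizer,
                                   optimization.ExponentialMovingAverage):
    self.optimizer.swap_weights()
  return logs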
official/core/base_trainer_test.py
@@ -20,6 +20,7 @@ import sys
 from absl.testing import parameterized
 import numpy as np
+import orbit
 import portpicker
 import tensorflow as tf
@@ -111,15 +112,14 @@ class MockAsyncTrainer(trainer_lib._AsyncTrainer):
         aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)

     train_dataset = self.distribute_dataset(dataset_fn)
-    trainer_lib.orbit.StandardTrainer.__init__(
-        self, train_dataset,
-        options=trainer_lib.orbit.StandardTrainerOptions())
+    orbit.StandardTrainer.__init__(
+        self, train_dataset, options=orbit.StandardTrainerOptions())

-    eval_dataset = self.distribute_dataset(dataset_fn)
-    trainer_lib.orbit.StandardEvaluator.__init__(
+    validation_dataset = self.distribute_dataset(dataset_fn)
+    orbit.StandardEvaluator.__init__(
         self,
-        eval_dataset,
-        options=trainer_lib.orbit.StandardEvaluatorOptions(
+        validation_dataset,
+        options=orbit.StandardEvaluatorOptions(
             use_tf_while_loop=True))

   def train_loop_begin(self):
     self.global_step.assign(0)
@@ -185,6 +185,30 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -185,6 +185,30 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
self
.
assertIn
(
'training_loss'
,
logs
)
self
.
assertIn
(
'training_loss'
,
logs
)
self
.
assertIn
(
'learning_rate'
,
logs
)
self
.
assertIn
(
'learning_rate'
,
logs
)
@
combinations
.
generate
(
all_strategy_combinations
())
def
test_trainer_passing_datasets
(
self
,
distribution
):
with
distribution
.
scope
():
task
=
mock_task
.
MockTask
(
self
.
_config
)
train_dataset
=
orbit
.
utils
.
make_distributed_dataset
(
distribution
,
task
.
build_inputs
,
self
.
_config
.
task
.
train_data
)
validation_dataset
=
orbit
.
utils
.
make_distributed_dataset
(
distribution
,
task
.
build_inputs
,
self
.
_config
.
task
.
validation_data
)
self
.
_config
.
task
.
train_data
=
None
self
.
_config
.
task
.
validation_data
=
None
trainer
=
trainer_lib
.
Trainer
(
self
.
_config
,
task
,
model
=
task
.
build_model
(),
optimizer
=
task
.
create_optimizer
(
self
.
_config
.
trainer
.
optimizer_config
,
self
.
_config
.
runtime
),
train_dataset
=
train_dataset
,
validation_dataset
=
validation_dataset
)
logs
=
trainer
.
train
(
tf
.
convert_to_tensor
(
5
,
dtype
=
tf
.
int32
))
self
.
assertIn
(
'training_loss'
,
logs
)
self
.
assertIn
(
'learning_rate'
,
logs
)
logs
=
trainer
.
evaluate
(
tf
.
convert_to_tensor
(
5
,
dtype
=
tf
.
int32
))
self
.
assertIn
(
'validation_loss'
,
logs
)
def
test_base_async_trainer
(
self
):
def
test_base_async_trainer
(
self
):
if
TPU_TEST
or
GPU_TEST
:
if
TPU_TEST
or
GPU_TEST
:
self
.
skipTest
(
'Aysnc training is not available on GPU/GPU.'
)
self
.
skipTest
(
'Aysnc training is not available on GPU/GPU.'
)
...
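The new test builds its inputs with `orbit.utils.make_distributed_dataset` and then nulls out `config.task.train_data` / `validation_data`, proving the trainer really consumed the injected datasets (the relaxed `_validate_params` no longer demands those fields). In spirit, the helper is a thin wrapper over `tf.distribute`'s function-based distribution; a hedged, simplified equivalent:

def make_distributed(strategy, dataset_fn, *args, **kwargs):
  # Simplified sketch of orbit.utils.make_distributed_dataset: give each
  # input pipeline its InputContext so the factory can shard accordingly.
  def per_worker_fn(input_context):
    return dataset_fn(*args, input_context=input_context, **kwargs)
  return strategy.distribute_datasets_from_function(per_worker_fn)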
@@ -204,7 +228,7 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
   def test_async_trainer_train(self):
     if TPU_TEST or GPU_TEST:
-      self.skipTest('Aysnc training is not available on GPU/GPU.')
+      self.skipTest('Aysnc training is not available on GPU/TPU.')
     num_workers = 3
     num_ps = 2
     cluster_resolver = create_in_process_cluster(num_workers, num_ps)