ModelZoo / ResNet50_tensorflow · Commits · 856622d3

Commit 856622d3 authored May 17, 2021 by Hongkun Yu, committed by A. Unique TensorFlower on May 17, 2021

Internal change

PiperOrigin-RevId: 374244811
parent fb9f35c8

Showing 2 changed files with 72 additions and 28 deletions (+72 -28)

official/core/base_trainer.py       +40 -20
official/core/base_trainer_test.py  +32 -8

official/core/base_trainer.py
@@ -19,6 +19,7 @@ The base trainer implements the Orbit `StandardTrainable` and
 interchangable and independent on model architectures and tasks.
 """
+import functools
 from typing import Union, Optional
 from absl import logging
 import gin
 import orbit

@@ -28,7 +29,6 @@ from official.core import base_task
 from official.core import config_definitions
 from official.modeling import optimization
-
 ExperimentConfig = config_definitions.ExperimentConfig
 TrainerConfig = config_definitions.TrainerConfig

@@ -143,6 +143,7 @@ class _AsyncTrainer(orbit.StandardTrainer, orbit.StandardEvaluator):
         `tf.distribute.InputContext` instance.
       *args: Any positional arguments to pass through to `dataset_or_fn`.
       **kwargs: Any keyword arguments to pass through to `dataset_or_fn`.
+
     Returns:
       A distributed Dataset.
     """

@@ -173,13 +174,18 @@ class Trainer(_AsyncTrainer):
   """Implements the common trainer shared for TensorFlow models."""

   # pylint: disable=super-init-not-called
-  def __init__(self,
+  def __init__(
+      self,
       config: ExperimentConfig,
       task: base_task.Task,
       model: tf.keras.Model,
       optimizer: tf.optimizers.Optimizer,
       train: bool = True,
       evaluate: bool = True,
+      train_dataset: Optional[Union[tf.data.Dataset,
+                                    tf.distribute.DistributedDataset]] = None,
+      validation_dataset: Optional[Union[tf.data.Dataset,
+                                         tf.distribute.DistributedDataset]] = None,
       checkpoint_exporter=None):
     """Initialize common trainer for TensorFlow models.

@@ -192,13 +198,22 @@ class Trainer(_AsyncTrainer):
         default to True.
       evaluate: bool, whether or not this trainer will be used for evaluation.
         default to True.
+      train_dataset: a dataset object created for training. With tf.distribute,
+        it needs to be a `DistributedDataset`.
+      validation_dataset: a dataset object created for evaluation. With
+        tf.distribute, it needs to be a `DistributedDataset`. The evaluator will
+        create a dataset iterator for each eval round, so the dataset does not
+        need to repeat.
       checkpoint_exporter: an object that has the `maybe_export_checkpoint`
         interface.
     """
     # Gets the current distribution strategy. If not inside any strategy scope,
     # it gets a single-replica no-op strategy.
     self._strategy = tf.distribute.get_strategy()
-    self._validate_params(config)
+    self._validate_params(
+        config,
+        check_train_data=train_dataset is None,
+        check_validation_data=validation_dataset is None)
     self._config = config
     self._task = task
     self._model = model
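
Since the evaluator re-creates its iterator for every eval round, the dataset handed in as `validation_dataset` can be finite. A small illustrative sketch (not part of the commit; the toy dataset below is an assumption):

    import tensorflow as tf

    # Illustrative only: a finite eval dataset works because the evaluator
    # builds a fresh iterator each evaluation round, so no .repeat() is needed.
    validation_dataset = (
        tf.data.Dataset.from_tensor_slices(tf.range(100))
        .batch(32)
        .prefetch(tf.data.AUTOTUNE))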

@@ -239,7 +254,7 @@ class Trainer(_AsyncTrainer):
     self.init_async()

     if train:
-      train_dataset = self.distribute_dataset(
+      train_dataset = train_dataset or self.distribute_dataset(
           self.task.build_inputs, self.config.task.train_data)
       orbit.StandardTrainer.__init__(
           self,

@@ -250,16 +265,19 @@ class Trainer(_AsyncTrainer):
           use_tpu_summary_optimization=config.trainer.allow_tpu_summary))

     if evaluate:
-      eval_dataset = self.distribute_dataset(
+      validation_dataset = validation_dataset or self.distribute_dataset(
           self.task.build_inputs, self.config.task.validation_data)
       orbit.StandardEvaluator.__init__(
           self,
-          eval_dataset,
+          validation_dataset,
           options=orbit.StandardEvaluatorOptions(
               use_tf_function=config.trainer.eval_tf_function,
               use_tf_while_loop=config.trainer.eval_tf_while_loop))

-  def _validate_params(self, config):
+  def _validate_params(self,
+                       config,
+                       check_train_data=True,
+                       check_validation_data=True):
     r"""Validates if the configuration object passed to the Trainer.

     The experiment configuration should be structured as:

@@ -270,6 +288,8 @@ class Trainer(_AsyncTrainer):
     Args:
       config: a namedtuple, dataclass, ConfigDict, etc.
+      check_train_data: whether to check task.train_data field.
+      check_validation_data: whether to check task.validation_data field.
     """
     if not hasattr(config, "trainer"):
       raise AttributeError("The trainer requires the configuration contains an"

@@ -279,11 +299,11 @@ class Trainer(_AsyncTrainer):
       raise AttributeError("The trainer requires the configuration contains an"
                            " attribute `task`.")

-    if not hasattr(config.task, "train_data"):
+    if check_train_data and not hasattr(config.task, "train_data"):
       raise AttributeError("The trainer requires the configuration contains an"
                            " attribute `task.train_data`.")

-    if not hasattr(config.task, "validation_data"):
+    if check_validation_data and not hasattr(config.task, "validation_data"):
       raise AttributeError("The trainer requires the configuration contains an"
                            " attribute `task.validation_data`.")
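
These relaxed checks only fire when the corresponding dataset was not passed in (the constructor sets check_train_data=train_dataset is None, and likewise for validation). A standalone sketch of that behaviour, using types.SimpleNamespace as an illustrative stand-in for the real ExperimentConfig:

    import types

    def validate(config, check_train_data=True, check_validation_data=True):
      # Mirrors the hasattr checks in Trainer._validate_params (simplified).
      if not hasattr(config, "trainer"):
        raise AttributeError("missing `trainer`")
      if not hasattr(config, "task"):
        raise AttributeError("missing `task`")
      if check_train_data and not hasattr(config.task, "train_data"):
        raise AttributeError("missing `task.train_data`")
      if check_validation_data and not hasattr(config.task, "validation_data"):
        raise AttributeError("missing `task.validation_data`")

    # A config without data fields now passes as long as datasets are supplied,
    # i.e. both check_* flags are False:
    config = types.SimpleNamespace(trainer=object(), task=types.SimpleNamespace())
    validate(config, check_train_data=False, check_validation_data=False)  # ok
    # With the defaults the old behaviour is preserved:
    # validate(config)  # raises AttributeError: missing `task.train_data`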

@@ -406,8 +426,8 @@ class Trainer(_AsyncTrainer):
     for metric in self.validation_metrics + [self.validation_loss]:
       metric.reset_states()
     # Swaps weights to test on weights moving average.
-    if self.optimizer and isinstance(self.optimizer,
-                                     optimization.ExponentialMovingAverage):
+    if self.optimizer and isinstance(self.optimizer,
+                                     optimization.ExponentialMovingAverage):
       self.optimizer.swap_weights()

   def eval_step(self, iterator):

@@ -451,8 +471,8 @@ class Trainer(_AsyncTrainer):
     # Swaps back weights after testing when EMA is used.
     # This happens after best checkpoint export so that average weights used for
     # eval are exported instead of regular weights.
-    if self.optimizer and isinstance(self.optimizer,
-                                     optimization.ExponentialMovingAverage):
+    if self.optimizer and isinstance(self.optimizer,
+                                     optimization.ExponentialMovingAverage):
       self.optimizer.swap_weights()
     return logs
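
Taken together, the base_trainer.py changes let a caller hand the Trainer pre-built (optionally distributed) datasets instead of having it build them from config.task.train_data / config.task.validation_data. A minimal usage sketch, assuming an existing `strategy`, `config` (an ExperimentConfig) and `task` (a base_task.Task); the calls mirror the new test below and are illustrative, not part of this commit:

    import orbit
    from official.core import base_trainer as trainer_lib

    with strategy.scope():
      # Build the (distributed) datasets up front...
      train_dataset = orbit.utils.make_distributed_dataset(
          strategy, task.build_inputs, config.task.train_data)
      validation_dataset = orbit.utils.make_distributed_dataset(
          strategy, task.build_inputs, config.task.validation_data)

      # ...and pass them in. Because datasets are supplied, the new
      # check_train_data / check_validation_data flags let _validate_params
      # skip the config.task.train_data / validation_data checks.
      trainer = trainer_lib.Trainer(
          config,
          task,
          model=task.build_model(),
          optimizer=task.create_optimizer(config.trainer.optimizer_config,
                                          config.runtime),
          train_dataset=train_dataset,
          validation_dataset=validation_dataset)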

official/core/base_trainer_test.py

@@ -20,6 +20,7 @@ import sys
 from absl.testing import parameterized
 import numpy as np
+import orbit
 import portpicker
 import tensorflow as tf

@@ -111,15 +112,14 @@ class MockAsyncTrainer(trainer_lib._AsyncTrainer):
         aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)

     train_dataset = self.distribute_dataset(dataset_fn)
-    trainer_lib.orbit.StandardTrainer.__init__(
-        self, train_dataset,
-        options=trainer_lib.orbit.StandardTrainerOptions())
+    orbit.StandardTrainer.__init__(
+        self, train_dataset, options=orbit.StandardTrainerOptions())

-    eval_dataset = self.distribute_dataset(dataset_fn)
-    trainer_lib.orbit.StandardEvaluator.__init__(
+    validation_dataset = self.distribute_dataset(dataset_fn)
+    orbit.StandardEvaluator.__init__(
         self,
-        eval_dataset,
-        options=trainer_lib.orbit.StandardEvaluatorOptions(use_tf_while_loop=True))
+        validation_dataset,
+        options=orbit.StandardEvaluatorOptions(use_tf_while_loop=True))

   def train_loop_begin(self):
     self.global_step.assign(0)

@@ -185,6 +185,30 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
     self.assertIn('training_loss', logs)
     self.assertIn('learning_rate', logs)

+  @combinations.generate(all_strategy_combinations())
+  def test_trainer_passing_datasets(self, distribution):
+    with distribution.scope():
+      task = mock_task.MockTask(self._config)
+      train_dataset = orbit.utils.make_distributed_dataset(
+          distribution, task.build_inputs, self._config.task.train_data)
+      validation_dataset = orbit.utils.make_distributed_dataset(
+          distribution, task.build_inputs, self._config.task.validation_data)
+      self._config.task.train_data = None
+      self._config.task.validation_data = None
+      trainer = trainer_lib.Trainer(
+          self._config,
+          task,
+          model=task.build_model(),
+          optimizer=task.create_optimizer(self._config.trainer.optimizer_config,
+                                          self._config.runtime),
+          train_dataset=train_dataset,
+          validation_dataset=validation_dataset)
+    logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
+    self.assertIn('training_loss', logs)
+    self.assertIn('learning_rate', logs)
+    logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
+    self.assertIn('validation_loss', logs)
+
   def test_base_async_trainer(self):
     if TPU_TEST or GPU_TEST:
       self.skipTest('Aysnc training is not available on GPU/GPU.')

@@ -204,7 +228,7 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
   def test_async_trainer_train(self):
     if TPU_TEST or GPU_TEST:
-      self.skipTest('Aysnc training is not available on GPU/GPU.')
+      self.skipTest('Aysnc training is not available on GPU/TPU.')
     num_workers = 3
     num_ps = 2
     cluster_resolver = create_in_process_cluster(num_workers, num_ps)