ModelZoo / ResNet50_tensorflow / Commits / 32e4ca51

Commit 32e4ca51, authored Nov 28, 2023 by qianyj
Parents: 9485aa1d, 71060f67

    Update code to v2.11.0
Showing 20 changed files (of 775 in the commit) with 1338 additions and 206 deletions.
official/core/base_trainer.py                        +6    -52
official/core/base_trainer_test.py                   +4    -26
official/core/config_definitions.py                  +43   -4
official/core/exp_factory.py                         +1    -1
official/core/export_base.py                         +12   -3
official/core/export_base_test.py                    +1    -1
official/core/file_writers.py                        +80   -0
official/core/file_writers_test.py                   +53   -0
official/core/input_reader.py                        +105  -29
official/core/registry.py                            +12   -3
official/core/registry_test.py                       +1    -1
official/core/savedmodel_checkpoint_manager.py       +244  -0
official/core/savedmodel_checkpoint_manager_test.py  +114  -0
official/core/task_factory.py                        +1    -1
official/core/test_utils.py                          +1    -1
official/core/tf_example_builder.py                  +144  -0
official/core/tf_example_builder_test.py             +165  -0
official/core/tf_example_feature_key.py              +62   -0
official/core/tf_example_feature_key_test.py         +49   -0
official/core/train_lib.py                           +240  -84
official/core/base_trainer.py  (view file @ 32e4ca51)

- # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+ # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
  ...
@@ -33,57 +33,6 @@ ExperimentConfig = config_definitions.ExperimentConfig
  TrainerConfig = config_definitions.TrainerConfig

- class Recovery:
-   """Built-in model blowup recovery module.
-
-   Checks the loss value by the given threshold. If applicable, recover the
-   model by reading the checkpoint on disk.
-   """
-
-   def __init__(self,
-                loss_upper_bound: float,
-                checkpoint_manager: tf.train.CheckpointManager,
-                recovery_begin_steps: int = 0,
-                recovery_max_trials: int = 3):
-     self.recover_counter = 0
-     self.recovery_begin_steps = recovery_begin_steps
-     self.recovery_max_trials = recovery_max_trials
-     self.loss_upper_bound = loss_upper_bound
-     self.checkpoint_manager = checkpoint_manager
-
-   def should_recover(self, loss_value, global_step):
-     if tf.math.is_nan(loss_value):
-       return True
-     if (global_step >= self.recovery_begin_steps and
-         loss_value > self.loss_upper_bound):
-       return True
-     return False
-
-   def maybe_recover(self, loss_value, global_step):
-     """Conditionally recovers the training by triggering checkpoint restoration.
-
-     Args:
-       loss_value: the loss value as a float.
-       global_step: the number of global training steps.
-
-     Raises:
-       RuntimeError: when recovery happens more than the max number of trials,
-         the job should crash.
-     """
-     if not self.should_recover(loss_value, global_step):
-       return
-     self.recover_counter += 1
-     if self.recover_counter > self.recovery_max_trials:
-       raise RuntimeError(
-           "The loss value is NaN or out of range after training loop and "
-           f"this happens {self.recover_counter} times.")
-     # Loads the previous good checkpoint.
-     checkpoint_path = self.checkpoint_manager.restore_or_initialize()
-     logging.warning(
-         "Recovering the model from checkpoint: %s. The loss value becomes "
-         "%f at step %d.", checkpoint_path, loss_value, global_step)
-
  class _AsyncTrainer(orbit.StandardTrainer, orbit.StandardEvaluator):
    """Trainer class for both sync and async Strategy."""
  ...
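For context on the deleted Recovery module above: it was built around a tf.train.CheckpointManager and polled once per loop step. The sketch below is plausible wiring, not code from this repository; `trainer_lib`, `train_step`, and `num_steps` are illustrative names only.

import tensorflow as tf

# Hypothetical wiring of the removed Recovery module into a custom loop.
ckpt = tf.train.Checkpoint(step=tf.Variable(0))
manager = tf.train.CheckpointManager(ckpt, '/tmp/model_dir', max_to_keep=1)
recovery = trainer_lib.Recovery(
    loss_upper_bound=100.0,          # treat losses above this as a blowup
    checkpoint_manager=manager,
    recovery_begin_steps=1000,       # ignore early warm-up spikes
    recovery_max_trials=3)           # raise RuntimeError on the 4th recovery

for step in range(num_steps):
  loss = train_step()
  # Restores the last good checkpoint when loss is NaN or above the bound.
  recovery.maybe_recover(float(loss), step)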
@@ -370,6 +319,11 @@ class Trainer(_AsyncTrainer):
      """Accesses the training checkpoint."""
      return self._checkpoint

+   @property
+   def checkpoint_exporter(self):
+     """Accesses the checkpoint exporter."""
+     return self._checkpoint_exporter
+
    def train_loop_end(self):
      """See base class."""
      self.join()
  ...
official/core/base_trainer_test.py  (view file @ 32e4ca51)

- # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+ # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
  ...
@@ -150,30 +150,6 @@ class MockAsyncTrainer(trainer_lib._AsyncTrainer):
      return self.eval_global_step.numpy()

- class RecoveryTest(tf.test.TestCase):
-
-   def test_recovery_module(self):
-     ckpt = tf.train.Checkpoint(v=tf.Variable(1, dtype=tf.int32))
-     model_dir = self.get_temp_dir()
-     manager = tf.train.CheckpointManager(ckpt, model_dir, max_to_keep=1)
-     recovery_module = trainer_lib.Recovery(
-         loss_upper_bound=1.0,
-         checkpoint_manager=manager,
-         recovery_begin_steps=1,
-         recovery_max_trials=1)
-     self.assertFalse(recovery_module.should_recover(1.1, 0))
-     self.assertFalse(recovery_module.should_recover(0.1, 1))
-     self.assertTrue(recovery_module.should_recover(1.1, 2))
-     # First triggers the recovery once.
-     recovery_module.maybe_recover(1.1, 10)
-     # Second time, it raises.
-     with self.assertRaisesRegex(RuntimeError, 'The loss value is NaN .*'):
-       recovery_module.maybe_recover(1.1, 10)
-
  class TrainerTest(tf.test.TestCase, parameterized.TestCase):

    def setUp(self):
  ...
@@ -343,7 +319,9 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
        self.assertFalse(trainer.optimizer.dynamic)
        self.assertEqual(trainer.optimizer.initial_scale, loss_scale)
      else:
-       self.assertIsInstance(trainer.optimizer, tf.keras.optimizers.SGD)
+       self.assertIsInstance(
+           trainer.optimizer,
+           (tf.keras.optimizers.SGD, tf.keras.optimizers.legacy.SGD))

      metrics = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
      self.assertIn('training_loss', metrics)
  ...
official/core/config_definitions.py  (view file @ 32e4ca51)

- # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+ # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
  ...
@@ -19,6 +19,7 @@ from typing import Optional, Sequence, Union
  from official.modeling.hyperparams import base_config
  from official.modeling.optimization.configs import optimization_config
+ from official.modeling.privacy import configs as dp_configs

  OptimizationConfig = optimization_config.OptimizationConfig
  ...
@@ -74,7 +75,35 @@ class DataConfig(base_config.Config):
      decoding when loading dataset from TFDS. Use comma to separate multiple
      features. The main use case is to skip the image/video decoding for
      better performance.
+   enable_shared_tf_data_service_between_parallel_trainers: A bool. When set
+     to true, only a single tf.data service is started, shared between all
+     trainers running simultaneously, e.g. when using Vizier to tune
+     hyperparameters. This saves CPU and RAM compared to running a separate
+     tf.data service per trainer. Note that if the batch size differs between
+     trainers, the field apply_tf_data_service_before_batching also needs to
+     be true so that only a single tf.data service instance is created. In
+     that case the tf.data service is applied before the batching operation,
+     so make sure not to apply any processing steps after batching (e.g. in
+     postprocess_fn), since they would not be parallelized by the tf.data
+     service and may slow down the pipeline. When using a shared tf.data
+     service, the tf.data dataset must be infinite, and a slow trainer may
+     skip certain training examples. More details about the shared tf.data
+     service can be found at:
+     https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#sharing_tfdata_service_with_concurrent_trainers.
+   apply_tf_data_service_before_batching: A bool. If set to True, the tf.data
+     service is applied before the batching operation. This is useful to make
+     sure only a single tf.data service instance is created when
+     enable_shared_tf_data_service_between_parallel_trainers is true and the
+     batch size changes between parallel trainers.
+   trainer_id: A string. The id of the trainer when multiple parallel
+     trainers run at the same time, e.g. in Vizier tuning. It is set
+     automatically when needed; users do not need to set it when creating
+     experiment configs.
+   seed: An optional seed to use for deterministic shuffling/preprocessing.
+   prefetch_buffer_size: An int specifying the buffer size of prefetch
+     datasets. If None, the buffer size is autotuned. Specifying this is
+     useful in case autotuning makes the buffer size too high and uses too
+     much memory.
    """
    input_path: Union[Sequence[str], str, base_config.Config] = ""
    tfds_name: str = ""
  ...
@@ -94,7 +123,11 @@ class DataConfig(base_config.Config):
    tfds_data_dir: str = ""
    tfds_as_supervised: bool = False
    tfds_skip_decoding_feature: str = ""
+   enable_shared_tf_data_service_between_parallel_trainers: bool = False
+   apply_tf_data_service_before_batching: bool = False
+   trainer_id: Optional[str] = None
    seed: Optional[int] = None
+   prefetch_buffer_size: Optional[int] = None

  @dataclasses.dataclass
  ...
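A hedged sketch of how the new fields might be set when several Vizier trials share one tf.data service. The pre-existing field names (input_path, global_batch_size, is_training, enable_tf_data_service, tf_data_service_address) are taken from the surrounding file and assumed unchanged; the address and trainer id are placeholders.

from official.core import config_definitions as cfg

# Two parallel trainers with different batch sizes sharing a single
# tf.data service instance; values are illustrative.
data = cfg.DataConfig(
    input_path='/data/train*',
    global_batch_size=128,  # may differ per trial
    is_training=True,
    enable_tf_data_service=True,
    tf_data_service_address='grpc://tf-data-dispatcher:5050',
    enable_shared_tf_data_service_between_parallel_trainers=True,
    # Required when batch sizes differ, so that a single service instance is
    # created (the service is then applied before batching).
    apply_tf_data_service_before_batching=True,
    trainer_id='trial-0',  # normally set automatically
    seed=42,
    prefetch_buffer_size=16)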
@@ -189,8 +222,8 @@ class TrainerConfig(base_config.Config):
      is only used in continuous_train_and_eval and continuous_eval modes.
      Default value is 1 hr.
    train_steps: number of train steps.
-   validation_steps: number of eval steps. If `None`, the entire eval dataset
-     is used.
+   validation_steps: number of eval steps. If -1, the entire eval dataset
+     is used.
    validation_interval: number of training steps to run between evaluations.
    best_checkpoint_export_subdir: if set, the trainer will keep track of the
      best evaluation metric, and export the corresponding best checkpoint under
  ...
@@ -240,11 +273,17 @@

  @dataclasses.dataclass
  class TaskConfig(base_config.Config):
    """Config passed to task."""
    init_checkpoint: str = ""
    model: Optional[base_config.Config] = None
    train_data: DataConfig = DataConfig()
    validation_data: DataConfig = DataConfig()
    name: Optional[str] = None
+   # Configs for differential privacy.
+   # These configs are only effective if you use create_optimizer in
+   # tensorflow_models/official/core/base_task.py.
+   differential_privacy_config: Optional[
+       dp_configs.DifferentialPrivacyConfig] = None

  @dataclasses.dataclass
  ...
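A minimal sketch of attaching the new field. It assumes `DifferentialPrivacyConfig` can be default-constructed like other base_config.Config dataclasses; its own fields are not shown in this diff.

from official.core import config_definitions as cfg
from official.modeling.privacy import configs as dp_configs

# Per the comment above, this only takes effect when the task builds its
# optimizer via create_optimizer in official/core/base_task.py.
task = cfg.TaskConfig(
    differential_privacy_config=dp_configs.DifferentialPrivacyConfig())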
official/core/exp_factory.py  (view file @ 32e4ca51)

- # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+ # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
  ...
official/core/export_base.py  (view file @ 32e4ca51)

- # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+ # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
  ...
@@ -67,6 +67,15 @@ class ExportModule(tf.Module, metaclass=abc.ABCMeta):
      if inference_step is not None:
        self.inference_step = functools.partial(inference_step, model=self.model)
      else:
+       if issubclass(type(model), tf.keras.Model):
+         # Default to self.model.call instead of self.model.__call__ to avoid
+         # the keras tracing logic designed for training.
+         # Since most Model Garden `call` methods don't have a training kwarg,
+         # or default it to False, we don't pass anything here.
+         # Please pass a custom inference step if your model defaults to
+         # training=True.
+         self.inference_step = self.model.call
+       else:
          self.inference_step = functools.partial(
              self.model.__call__, training=False)
  ...
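The default above is only used when no `inference_step` is supplied. A sketch of passing a custom step for a model whose `call` defaults to `training=True`; the concrete `MyExportModule` subclass and the `params` argument are assumptions for illustration.

import tensorflow as tf

class ToyModel(tf.keras.Model):
  def call(self, inputs, training=True):  # training defaults to True
    return inputs

def serve_step(inputs, model):
  # Force inference behavior regardless of the model's training default.
  return model(inputs, training=False)

# Hypothetical concrete subclass of ExportModule; constructor args other
# than `model` and `inference_step` are assumptions.
module = MyExportModule(params=params, model=ToyModel(),
                        inference_step=serve_step)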
official/core/export_base_test.py  (view file @ 32e4ca51)

- # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+ # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
  ...
official/core/file_writers.py  (new file, 0 → 100644)  (view file @ 32e4ca51)

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""File writer functions for dataset preparation, infra validation, and unit tests."""

import io
from typing import Optional, Sequence, Union

import tensorflow as tf


def write_small_dataset(examples: Sequence[Union[tf.train.Example,
                                                 tf.train.SequenceExample]],
                        output_path: str,
                        file_type: str = 'tfrecord') -> None:
  """Writes `examples` to a file at `output_path` with type `file_type`.

  CAVEAT: This function is not recommended for writing large datasets, since it
  will loop through `examples` and perform write operation sequentially.

  Args:
    examples: List of tf.train.Example or tf.train.SequenceExample.
    output_path: Output path for the dataset.
    file_type: A string indicating the file format, could be: 'tfrecord',
      'tfrecords', 'tfrecord_compressed', 'tfrecords_gzip', 'riegeli'. The
      string is case insensitive.
  """
  file_type = file_type.lower()

  if file_type == 'tfrecord' or file_type == 'tfrecords':
    _write_tfrecord(examples, output_path)
  elif file_type == 'tfrecord_compressed' or file_type == 'tfrecords_gzip':
    _write_tfrecord(examples, output_path,
                    tf.io.TFRecordOptions(compression_type='GZIP'))
  elif file_type == 'riegeli':
    _write_riegeli(examples, output_path)
  else:
    raise ValueError(f'Unknown file_type: {file_type}')


def _write_tfrecord(examples: Sequence[Union[tf.train.Example,
                                             tf.train.SequenceExample]],
                    output_path: str,
                    options: Optional[tf.io.TFRecordOptions] = None) -> None:
  """Writes `examples` to a TFRecord file at `output_path`.

  Args:
    examples: A list of tf.train.Example.
    output_path: Output path for the dataset.
    options: Options used for manipulating TFRecord files.
  """
  with tf.io.TFRecordWriter(output_path, options) as writer:
    for example in examples:
      writer.write(example.SerializeToString())


def _write_riegeli(examples: Sequence[Union[tf.train.Example,
                                            tf.train.SequenceExample]],
                   output_path: str) -> None:
  """Writes `examples` to a Riegeli file at `output_path`.

  Args:
    examples: A list of tf.train.Example.
    output_path: Output path for the dataset.
  """
  with io.FileIO(output_path, 'wb') as fileio:
    import riegeli  # pylint: disable=g-import-not-at-top
    with riegeli.RecordWriter(fileio) as writer:
      writer.write_messages(examples)
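A minimal usage sketch of the new writer; output paths are placeholders, and the second call exercises the case-insensitive GZIP branch.

import tensorflow as tf

from official.core import file_writers

example = tf.train.Example()
example.features.feature['label'].int64_list.value.append(1)

# Plain TFRecord and GZIP-compressed TFRecord; file_type is case-insensitive.
file_writers.write_small_dataset([example], '/tmp/train.tfrecord', 'tfrecord')
file_writers.write_small_dataset([example], '/tmp/train_gzip.tfrecord',
                                 'TFRecords_GZIP')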
official/core/file_writers_test.py  (new file, 0 → 100644)  (view file @ 32e4ca51)

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
...

"""Tests for file_writers."""

import os

from absl.testing import parameterized
import tensorflow as tf

from official.core import file_writers
from official.core import tf_example_builder


class FileWritersTest(tf.test.TestCase, parameterized.TestCase):

  def setUp(self):
    super().setUp()
    example_builder = tf_example_builder.TfExampleBuilder()
    example_builder.add_bytes_feature('foo', 'Hello World!')
    self._example = example_builder.example

  @parameterized.parameters('tfrecord', 'TFRecord', 'tfrecords',
                            'tfrecord_compressed', 'TFRecord_Compressed',
                            'tfrecords_gzip')
  def test_write_small_dataset_success(self, file_type):
    temp_dir = self.create_tempdir()
    temp_dataset_file = os.path.join(temp_dir.full_path, 'train')
    file_writers.write_small_dataset([self._example], temp_dataset_file,
                                     file_type)
    self.assertTrue(os.path.exists(temp_dataset_file))

  def test_write_small_dataset_unrecognized_format(self):
    file_type = 'bar'
    temp_dir = self.create_tempdir()
    temp_dataset_file = os.path.join(temp_dir.full_path, 'train')
    with self.assertRaises(ValueError):
      file_writers.write_small_dataset([self._example], temp_dataset_file,
                                       file_type)


if __name__ == '__main__':
  tf.test.main()
official/core/input_reader.py  (view file @ 32e4ca51)

- # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+ # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
  ...
@@ -160,16 +160,38 @@ def _read_tfds(tfds_builder: tfds.core.DatasetBuilder,
    """Reads a dataset from tfds."""
    # No op if exist.
    tfds_builder.download_and_prepare()
+   decoders = {}
+   if tfds_skip_decoding_feature:
+     for skip_feature in tfds_skip_decoding_feature.split(','):
+       decoders[skip_feature.strip()] = tfds.decode.SkipDecoding()
+   if tfds_builder.info.splits:
+     num_shards = len(tfds_builder.info.splits[tfds_split].file_instructions)
+   else:
+     # The tfds mock path often does not provide splits.
+     num_shards = 1
+   if input_context and num_shards < input_context.num_input_pipelines:
+     # The number of files in the dataset split is smaller than the number of
+     # input pipelines. We read the entire dataset first and then shard in
+     # host memory.
+     read_config = tfds.ReadConfig(
+         interleave_cycle_length=cycle_length,
+         interleave_block_length=block_length,
+         input_context=None,
+         shuffle_seed=seed)
+     dataset = tfds_builder.as_dataset(
+         split=tfds_split,
+         shuffle_files=is_training,
+         as_supervised=tfds_as_supervised,
+         decoders=decoders,
+         read_config=read_config)
+     dataset = dataset.shard(input_context.num_input_pipelines,
+                             input_context.input_pipeline_id)
+   else:
      read_config = tfds.ReadConfig(
          interleave_cycle_length=cycle_length,
          interleave_block_length=block_length,
          input_context=input_context,
          shuffle_seed=seed)
-     decoders = {}
-     if tfds_skip_decoding_feature:
-       for skip_feature in tfds_skip_decoding_feature.split(','):
-         decoders[skip_feature.strip()] = tfds.decode.SkipDecoding()
      dataset = tfds_builder.as_dataset(
          split=tfds_split,
          shuffle_files=is_training,
  ...
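To make the new fallback concrete: `Dataset.shard(n, i)` keeps every n-th element starting at index i, so after reading the whole (small) split, each input pipeline sees a disjoint slice. A standalone sketch:

import tensorflow as tf

dataset = tf.data.Dataset.range(8)
# With 2 input pipelines, pipeline 0 gets 0,2,4,6 and pipeline 1 gets 1,3,5,7.
pipeline_0 = dataset.shard(num_shards=2, index=0)
pipeline_1 = dataset.shard(num_shards=2, index=1)
print(list(pipeline_0.as_numpy_iterator()))  # [0, 2, 4, 6]
print(list(pipeline_1.as_numpy_iterator()))  # [1, 3, 5, 7]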
@@ -270,6 +292,8 @@ class InputReader:
      self._transform_and_batch_fn = transform_and_batch_fn
      self._postprocess_fn = postprocess_fn
      self._seed = params.seed
+     self._prefetch_buffer_size = (
+         params.prefetch_buffer_size or tf.data.experimental.AUTOTUNE)

      # When tf.data service is enabled, each data service worker should get
      # different random seeds. Thus, we set `seed` to None.
  ...
@@ -282,13 +306,36 @@ class InputReader:
      self._enable_tf_data_service = (
          params.enable_tf_data_service and params.tf_data_service_address)
      self._tf_data_service_address = params.tf_data_service_address
+     self._enable_shared_tf_data_service_between_parallel_trainers = (
+         params.enable_shared_tf_data_service_between_parallel_trainers)
+     self._apply_tf_data_service_before_batching = (
+         params.apply_tf_data_service_before_batching)
+     self._trainer_id = params.trainer_id
      if self._enable_tf_data_service:
        # Add a random seed as the tf.data service job name suffix, so tf.data
        # service doesn't reuse the previous state if TPU worker gets
        # preempted.
+       # It's necessary to add the global batch size into the tf.data service
+       # job name: when tuning the batch size with Vizier while tf.data
+       # service is enabled, the job name should differ between Vizier trials,
+       # because once the batch size changes, the dataset is a different
+       # instance from the tf.data perspective and a different job name should
+       # be used. Otherwise, the model would read tensors from the incorrect
+       # tf.data service job, which would cause a dimension mismatch on the
+       # batch size dimension.
-       self._tf_data_service_job_name = (
-           params.tf_data_service_job_name + str(self.static_randnum))
+       self._tf_data_service_job_name = (
+           f'{params.tf_data_service_job_name}_bs{params.global_batch_size}_'
+           f'{self.static_randnum}')
        self._enable_round_robin_tf_data_service = params.get(
            'enable_round_robin_tf_data_service', False)
+       if self._enable_shared_tf_data_service_between_parallel_trainers:
+         # When the shared tf.data service is enabled, only a single tf.data
+         # service instance should be created and shared between parallel
+         # trainers. If the global batch size differs across trainers,
+         # params.apply_tf_data_service_before_batching should be set to true
+         # because tf.data services with different batch sizes are considered
+         # separate instances.
+         self._tf_data_service_job_name = (
+             f'{params.tf_data_service_job_name}_{self.static_randnum}')

    @property
    def tfds_info(self) -> tfds.core.DatasetInfo:
  ...
@@ -411,6 +458,19 @@ class InputReader:
        dataset = dataset.repeat()
        dataset = dataset.shuffle(self._shuffle_buffer_size, seed=self._seed)

+     # Applies the tf.data service before batching operations. This is useful
+     # when the tf.data service is shared between parallel trainers and the
+     # batch size changes between trainers: if applied after batching, tf.data
+     # services with different batch sizes are considered different instances,
+     # which makes them hard to share between parallel trainers. However, if
+     # there are additional expensive operations in
+     # self._transform_and_batch_fn or self._postprocess_fn, the entire
+     # tf.data pipeline could be slowed down; in that case, try to move those
+     # dataset operations into earlier stages if possible.
+     if (self._enable_shared_tf_data_service_between_parallel_trainers and
+         self._apply_tf_data_service_before_batching):
+       dataset = self._maybe_apply_data_service(dataset, input_context)
+
      if self._transform_and_batch_fn is not None:
        dataset = self._transform_and_batch_fn(dataset, input_context)
      else:
  ...
@@ -436,13 +496,18 @@ class InputReader:
        num_consumers = input_context.num_input_pipelines * (
            replicas_per_input_pipeline)
        range_dataset = tf.data.Dataset.range(replicas_per_input_pipeline)
+       tfds_kwargs = {
+           'processing_mode': 'parallel_epochs',
+           'service': self._tf_data_service_address,
+           'job_name': self._tf_data_service_job_name,
+           'num_consumers': num_consumers
+       }
+       if self._enable_shared_tf_data_service_between_parallel_trainers:
+         raise ValueError('Shared tf.data service does not support '
+                          'round-robin tf.data service.')
        dataset = range_dataset.map(
            lambda i: dataset.apply(  # pylint: disable=g-long-lambda
                tf.data.experimental.service.distribute(
-                   processing_mode='parallel_epochs',
-                   service=self._tf_data_service_address,
-                   job_name=self._tf_data_service_job_name,
                    consumer_index=base_consumer_index + i,
-                   num_consumers=num_consumers)))
+                   **tfds_kwargs)))
        # Use parallel interleave to read multiple batches from a tf.data
        # service worker in parallel.
        dataset = dataset.interleave(
  ...
@@ -451,11 +516,21 @@ class InputReader:
            num_parallel_calls=replicas_per_input_pipeline,
            deterministic=True)
      else:
+       tfds_kwargs = {
+           'processing_mode': 'parallel_epochs',
+           'service': self._tf_data_service_address,
+           'job_name': self._tf_data_service_job_name,
+       }
+       if self._enable_shared_tf_data_service_between_parallel_trainers:
+         tfds_kwargs.update({
+             'processing_mode':
+                 tf.data.experimental.service.ShardingPolicy.OFF,
+             'cross_trainer_cache':
+                 tf.data.experimental.service.CrossTrainerCache(
+                     trainer_id=self._trainer_id)
+         })
        dataset = dataset.apply(
-           tf.data.experimental.service.distribute(
-               processing_mode='parallel_epochs',
-               service=self._tf_data_service_address,
-               job_name=self._tf_data_service_job_name))
+           tf.data.experimental.service.distribute(**tfds_kwargs))
      return dataset

    def read(self,
  ...
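For reference, the shared-service branch above combines `ShardingPolicy.OFF` with a `CrossTrainerCache`. Outside the reader, the same call looks roughly like the sketch below; the dispatcher address, job name, and trainer id are placeholders.

import tensorflow as tf

dataset = tf.data.Dataset.range(1_000_000).repeat()  # must be infinite
dataset = dataset.apply(
    tf.data.experimental.service.distribute(
        processing_mode=tf.data.experimental.service.ShardingPolicy.OFF,
        service='grpc://tf-data-dispatcher:5050',
        job_name='shared_job',
        cross_trainer_cache=tf.data.experimental.service.CrossTrainerCache(
            trainer_id='trainer-0')))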
@@ -463,16 +538,17 @@ class InputReader:
             dataset: Optional[tf.data.Dataset] = None) -> tf.data.Dataset:
      """Generates a tf.data.Dataset object."""
      if dataset is None:
        dataset = self._read_data_source(self._matched_files, self._dataset_fn,
                                         input_context, self._tfds_builder)
      dataset = self._decode_and_parse_dataset(dataset, self._global_batch_size,
                                               input_context)
      dataset = _maybe_map_fn(dataset, self._postprocess_fn)
-     dataset = self._maybe_apply_data_service(dataset, input_context)
+     if not (self._enable_shared_tf_data_service_between_parallel_trainers and
+             self._apply_tf_data_service_before_batching):
+       dataset = self._maybe_apply_data_service(dataset, input_context)
      if self._deterministic is not None:
        options = tf.data.Options()
-       options.experimental_deterministic = self._deterministic
+       options.deterministic = self._deterministic
        dataset = dataset.with_options(options)
-     return dataset.prefetch(tf.data.experimental.AUTOTUNE)
+     return dataset.prefetch(self._prefetch_buffer_size)
official/core/registry.py  (view file @ 32e4ca51)

- # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+ # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
  ...
@@ -13,6 +13,7 @@
  # limitations under the License.

  """Registry utility."""
+ from absl import logging

  def register(registered_collection, reg_key):
  ...
@@ -54,6 +55,14 @@ def register(registered_collection, reg_key):
      leaf_reg_key = reg_key

    if leaf_reg_key in collection:
+     if "beta" in fn_or_cls.__module__:
+       # TODO(yeqing): Clean this temporary branch for beta.
+       logging.warn(
+           "Duplicate registration of beta module name %r new %r old %r",
+           reg_key, collection[leaf_reg_key], fn_or_cls.__module__)
+       return fn_or_cls
+     else:
        raise KeyError("Function or class {} registered multiple times.".format(
            leaf_reg_key))
  ...
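A small sketch of the registration pattern this guards; the collection dict and key are illustrative, and `register` is assumed to return a decorator as its usage elsewhere in official/core suggests.

from official.core import registry

_REGISTERED_MODELS = {}

@registry.register(_REGISTERED_MODELS, 'resnet')
class ResNet:
  pass

# Registering 'resnet' again raises KeyError, unless the duplicate comes from
# a module whose path contains "beta", which hits the warn-and-return branch
# added above.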
official/core/registry_test.py  (view file @ 32e4ca51)

- # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+ # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
  ...
official/core/savedmodel_checkpoint_manager.py  (new file, 0 → 100644)  (view file @ 32e4ca51)

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
...

"""Custom checkpoint manager that also exports saved models."""

import os
import re
import time
from typing import Callable, List, Mapping, Optional, Union

from absl import logging
import tensorflow as tf

SAVED_MODULES_PATH_SUFFIX = 'saved_modules'


def make_saved_modules_directory_name(checkpoint_name: str) -> str:
  return f'{checkpoint_name}_{SAVED_MODULES_PATH_SUFFIX}'


class SavedModelCheckpointManager(tf.train.CheckpointManager):
  """A CheckpointManager that also exports `SavedModel`s."""

  def __init__(self,
               checkpoint: tf.train.Checkpoint,
               directory: str,
               max_to_keep: int,
               modules_to_export: Optional[Mapping[str, tf.Module]] = None,
               keep_checkpoint_every_n_hours: Optional[int] = None,
               checkpoint_name: str = 'ckpt',
               step_counter: Optional[tf.Variable] = None,
               checkpoint_interval: Optional[int] = None,
               init_fn: Optional[Callable[[], None]] = None):
    """See base class."""
    super().__init__(
        checkpoint=checkpoint,
        directory=directory,
        max_to_keep=max_to_keep,
        keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
        checkpoint_name=checkpoint_name,
        step_counter=step_counter,
        checkpoint_interval=checkpoint_interval,
        init_fn=init_fn)
    self._modules_to_export = modules_to_export
    self._savedmodels = self.get_existing_savedmodels()

  def save(self,
           checkpoint_number: Optional[int] = None,
           check_interval: bool = True,
           options: Optional[tf.train.CheckpointOptions] = None):
    """See base class."""
    checkpoint_path = super().save(
        checkpoint_number=checkpoint_number,
        check_interval=check_interval,
        options=options)
    if not checkpoint_path:
      # Nothing got written.
      return
    if not self._modules_to_export:
      # No modules to export.
      logging.info('Skip saving SavedModel due to empty modules_to_export.')
      return checkpoint_path

    # Save the models for the checkpoint that just got written.
    saved_modules_directory = make_saved_modules_directory_name(
        checkpoint_path)
    for model_name, model in self._modules_to_export.items():
      signatures = getattr(model, 'saved_model_signatures', None)
      tf.saved_model.save(
          obj=model,
          export_dir=os.path.join(saved_modules_directory, model_name),
          signatures=signatures)

    saved_modules_directories_to_keep = [
        make_saved_modules_directory_name(ckpt) for ckpt in self.checkpoints
    ]
    existing_saved_modules_dirs = self.get_existing_savedmodels()

    self._savedmodels = []
    # Keep savedmodels in the same order as checkpoints (from oldest to
    # newest).
    for saved_modules_dir_to_keep in saved_modules_directories_to_keep:
      if saved_modules_dir_to_keep in existing_saved_modules_dirs:
        self._savedmodels.append(saved_modules_dir_to_keep)

    for existing_saved_modules_dir in existing_saved_modules_dirs:
      if existing_saved_modules_dir not in self._savedmodels:
        tf.io.gfile.rmtree(existing_saved_modules_dir)

    return checkpoint_path

  def get_existing_savedmodels(self) -> List[str]:
    """Gets a list of all existing SavedModel paths in `directory`.

    Returns:
      A list of all existing SavedModel paths.
    """
    saved_modules_glob = make_saved_modules_directory_name(
        self._checkpoint_prefix + '-*')
    return tf.io.gfile.glob(saved_modules_glob)

  @property
  def latest_savedmodel(self) -> Union[str, None]:
    """The path of the most recent SavedModel in `directory`.

    Returns:
      The latest SavedModel path. If there are no SavedModels, returns `None`.
    """
    if self._savedmodels:
      return self._savedmodels[-1]
    return None

  @property
  def savedmodels(self) -> List[str]:
    """A list of managed SavedModels.

    Returns:
      A list of SavedModel paths, sorted from oldest to newest.
    """
    return self._savedmodels

  @property
  def modules_to_export(self) -> Union[Mapping[str, tf.Module], None]:
    return self._modules_to_export

  def get_savedmodel_number_from_path(
      self, savedmodel_path: str) -> Union[int, None]:
    """Gets the savedmodel_number/checkpoint_number from a savedmodel filepath.

    The savedmodel_number is the global step when used with an orbit
    controller.

    Args:
      savedmodel_path: savedmodel directory path.

    Returns:
      Savedmodel number, or None if no matched pattern is found in the
      savedmodel path.
    """
    pattern = rf'\d+_{SAVED_MODULES_PATH_SUFFIX}$'
    savedmodel_number = re.search(pattern, savedmodel_path)
    if savedmodel_number:
      savedmodel_number = savedmodel_number.group()
      return int(savedmodel_number[:-len(SAVED_MODULES_PATH_SUFFIX) - 1])
    return None

  def savedmodels_iterator(self,
                           min_interval_secs: float = 0,
                           timeout: Optional[float] = None,
                           timeout_fn: Optional[Callable[[], bool]] = None):
    """Continuously yields new SavedModel files as they appear.

    The iterator only checks for new savedmodels when control flow has been
    reverted to it. The logic is the same as `train.checkpoints_iterator`.

    Args:
      min_interval_secs: The minimum number of seconds between yielding
        savedmodels.
      timeout: The maximum number of seconds to wait between savedmodels. If
        left as `None`, then the process will wait indefinitely.
      timeout_fn: Optional function to call after a timeout. If the function
        returns True, then it means that no new savedmodels will be generated
        and the iterator will exit. The function is called with no arguments.

    Yields:
      String paths to latest SavedModel files as they arrive.
    """
    savedmodel_path = None
    while True:
      new_savedmodel_path = self.wait_for_new_savedmodel(
          savedmodel_path, timeout=timeout)
      if new_savedmodel_path is None:
        if not timeout_fn:
          # timed out
          logging.info('Timed-out waiting for a savedmodel.')
          return
        if timeout_fn():
          # The timeout_fn indicated that we are truly done.
          return
        else:
          # The timeout_fn indicated that more savedmodels may come.
          continue
      start = time.time()
      savedmodel_path = new_savedmodel_path
      yield savedmodel_path
      time_to_next_eval = start + min_interval_secs - time.time()
      if time_to_next_eval > 0:
        time.sleep(time_to_next_eval)

  def wait_for_new_savedmodel(
      self,
      last_savedmodel: Optional[str] = None,
      seconds_to_sleep: float = 1.0,
      timeout: Optional[float] = None) -> Union[str, None]:
    """Waits until a new savedmodel file is found.

    Args:
      last_savedmodel: The last savedmodel path used or `None` if we're
        expecting a savedmodel for the first time.
      seconds_to_sleep: The number of seconds to sleep for before looking for
        a new savedmodel.
      timeout: The maximum number of seconds to wait. If left as `None`, then
        the process will wait indefinitely.

    Returns:
      A new savedmodel path, or None if the timeout was reached.
    """
    logging.info('Waiting for new savedmodel at %s', self._directory)
    stop_time = time.time() + timeout if timeout is not None else None

    last_savedmodel_number = 0
    if last_savedmodel:
      last_savedmodel_number = self.get_savedmodel_number_from_path(
          last_savedmodel)

    while True:
      if stop_time is not None and time.time() + seconds_to_sleep > stop_time:
        return None

      existing_savedmodels = {}
      for savedmodel_path in self.get_existing_savedmodels():
        savedmodel_number = self.get_savedmodel_number_from_path(
            savedmodel_path)
        if savedmodel_number is not None:
          existing_savedmodels[savedmodel_number] = savedmodel_path

      # Find the first savedmodel with a larger step number as the next
      # savedmodel.
      savedmodel_path = None
      existing_savedmodels = dict(sorted(existing_savedmodels.items()))
      for savedmodel_number in existing_savedmodels:
        if savedmodel_number > last_savedmodel_number:
          savedmodel_path = existing_savedmodels[savedmodel_number]
          break
      if savedmodel_path:
        logging.info('Found new savedmodel at %s', savedmodel_path)
        return savedmodel_path
      else:
        time.sleep(seconds_to_sleep)
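A sketch of the intended producer/consumer split, using only the API defined above; the model, step number, and directory are placeholders.

import os
import tensorflow as tf
from official.core import savedmodel_checkpoint_manager

model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])
manager = savedmodel_checkpoint_manager.SavedModelCheckpointManager(
    checkpoint=tf.train.Checkpoint(model=model),
    directory='/tmp/model_dir',
    max_to_keep=3,
    modules_to_export={'serving': model})

# Training job: each save() writes a checkpoint plus one SavedModel per module.
manager.save(checkpoint_number=100)

# Evaluation job: blocks until new SavedModels appear, yielded in step order.
for savedmodel_path in manager.savedmodels_iterator(timeout=60):
  loaded = tf.saved_model.load(os.path.join(savedmodel_path, 'serving'))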
official/core/savedmodel_checkpoint_manager_test.py  (new file, 0 → 100644)  (view file @ 32e4ca51)

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
...

import os
import time
from typing import Iterable

import tensorflow as tf

from official.core import savedmodel_checkpoint_manager


def _models_exist(checkpoint_path: str, models: Iterable[str]) -> bool:
  for model_name in models:
    if not tf.io.gfile.isdir(
        os.path.join(
            savedmodel_checkpoint_manager.make_saved_modules_directory_name(
                checkpoint_path), model_name)):
      return False
  return True


class CheckpointManagerTest(tf.test.TestCase):

  def _create_manager(self,
                      max_to_keep: int = 1) -> tf.train.CheckpointManager:
    """Sets up a SavedModelCheckpointManager object.

    Args:
      max_to_keep: max number of savedmodels to keep.

    Returns:
      The created savedmodel manager.
    """
    models = {
        'model_1':
            tf.keras.Sequential(
                layers=[tf.keras.layers.Dense(8, input_shape=(16,))]),
        'model_2':
            tf.keras.Sequential(
                layers=[tf.keras.layers.Dense(16, input_shape=(32,))]),
    }
    checkpoint = tf.train.Checkpoint()
    manager = savedmodel_checkpoint_manager.SavedModelCheckpointManager(
        checkpoint=checkpoint,
        directory=self.get_temp_dir(),
        max_to_keep=max_to_keep,
        modules_to_export=models)
    return manager

  def test_max_to_keep(self):
    manager = self._create_manager()
    models = manager.modules_to_export
    first_path = manager.save()
    second_path = manager.save()

    savedmodel = (
        savedmodel_checkpoint_manager.make_saved_modules_directory_name(
            manager.latest_checkpoint))
    self.assertEqual(savedmodel, manager.latest_savedmodel)
    self.assertTrue(_models_exist(second_path, models.keys()))
    self.assertFalse(_models_exist(first_path, models.keys()))

  def test_returns_none_after_timeout(self):
    manager = self._create_manager()
    start = time.time()
    ret = manager.wait_for_new_savedmodel(
        None, timeout=1.0, seconds_to_sleep=0.5)
    end = time.time()
    self.assertIsNone(ret)
    # We've waited 0.5 second.
    self.assertGreater(end, start + 0.5)
    # The timeout kicked in.
    self.assertLess(end, start + 0.6)

  def test_saved_model_iterator(self):
    manager = self._create_manager(max_to_keep=2)
    self.assertIsNotNone(manager.save(checkpoint_number=1))
    self.assertIsNotNone(manager.save(checkpoint_number=2))
    self.assertIsNotNone(manager.save(checkpoint_number=3))

    # Savedmodels are in time order.
    expected_savedmodels = manager.savedmodels
    # Order not guaranteed.
    existing_savedmodels = manager.get_existing_savedmodels()
    savedmodels = list(manager.savedmodels_iterator(timeout=3.0))
    self.assertEqual(savedmodels, expected_savedmodels)
    self.assertEqual(set(savedmodels), set(existing_savedmodels))

  def test_saved_model_iterator_timeout_fn(self):
    manager = self._create_manager()
    timeout_fn_calls = [0]

    def timeout_fn():
      timeout_fn_calls[0] += 1
      return timeout_fn_calls[0] > 3

    results = list(
        manager.savedmodels_iterator(timeout=0.1, timeout_fn=timeout_fn))
    self.assertEqual([], results)
    self.assertEqual(4, timeout_fn_calls[0])


if __name__ == '__main__':
  tf.test.main()
official/core/task_factory.py  (view file @ 32e4ca51)

- # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+ # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
  ...
official/core/test_utils.py  (view file @ 32e4ca51)

- # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+ # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
  ...
official/core/tf_example_builder.py  (new file, 0 → 100644)  (view file @ 32e4ca51)

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
...

"""Builder class for preparing tf.train.Example."""

# https://www.python.org/dev/peps/pep-0563/#enabling-the-future-behavior-in-python-3-7
from __future__ import annotations

from typing import Mapping, Sequence, Union

import numpy as np
import tensorflow as tf

BytesValueType = Union[bytes, Sequence[bytes], str, Sequence[str]]

_to_array = lambda v: [v] if not isinstance(v, (list, np.ndarray)) else v
_to_bytes = lambda v: v.encode() if isinstance(v, str) else v
_to_bytes_array = lambda v: list(map(_to_bytes, _to_array(v)))


class TfExampleBuilder(object):
  """Builder class for preparing tf.train.Example.

  Read API doc at https://www.tensorflow.org/api_docs/python/tf/train/Example.

  Example usage:
    >>> example_builder = TfExampleBuilder()
    >>> example = (
            example_builder.add_bytes_feature('feature_a', 'foobarbaz')
            .add_ints_feature('feature_b', [1, 2, 3])
            .example)
  """

  def __init__(self) -> None:
    self._example = tf.train.Example()

  @property
  def example(self) -> tf.train.Example:
    """Returns a copy of the generated tf.train.Example proto."""
    return self._example

  @property
  def serialized_example(self) -> str:
    """Returns a serialized string of the generated tf.train.Example proto."""
    return self._example.SerializeToString()

  def set(self, example: tf.train.Example) -> TfExampleBuilder:
    """Sets the example."""
    self._example = example
    return self

  def reset(self) -> TfExampleBuilder:
    """Resets the example to an empty proto."""
    self._example = tf.train.Example()
    return self

  ###### Basic APIs for primitive data types ######
  def add_feature_dict(
      self, feature_dict: Mapping[str, tf.train.Feature]) -> TfExampleBuilder:
    """Adds the predefined `feature_dict` to the example.

    Note: Please prefer using the feature-type-specific methods.

    Args:
      feature_dict: A dictionary from tf.Example feature key to
        tf.train.Feature.

    Returns:
      The builder object for subsequent method calls.
    """
    for k, v in feature_dict.items():
      self._example.features.feature[k].CopyFrom(v)
    return self

  def add_feature(self, key: str,
                  feature: tf.train.Feature) -> TfExampleBuilder:
    """Adds predefined `feature` with `key` to the example.

    Args:
      key: String key of the feature.
      feature: The feature to be added to the example.

    Returns:
      The builder object for subsequent method calls.
    """
    self._example.features.feature[key].CopyFrom(feature)
    return self

  def add_bytes_feature(self, key: str,
                        value: BytesValueType) -> TfExampleBuilder:
    """Adds byte(s) or string(s) with `key` to the example.

    Args:
      key: String key of the feature.
      value: The byte(s) or string(s) to be added to the example.

    Returns:
      The builder object for subsequent method calls.
    """
    return self.add_feature(
        key,
        tf.train.Feature(
            bytes_list=tf.train.BytesList(value=_to_bytes_array(value))))

  def add_ints_feature(self, key: str,
                       value: Union[int, Sequence[int]]) -> TfExampleBuilder:
    """Adds integer(s) with `key` to the example.

    Args:
      key: String key of the feature.
      value: The integer(s) to be added to the example.

    Returns:
      The builder object for subsequent method calls.
    """
    return self.add_feature(
        key,
        tf.train.Feature(
            int64_list=tf.train.Int64List(value=_to_array(value))))

  def add_floats_feature(
      self, key: str,
      value: Union[float, Sequence[float]]) -> TfExampleBuilder:
    """Adds float(s) with `key` to the example.

    Args:
      key: String key of the feature.
      value: The float(s) to be added to the example.

    Returns:
      The builder object for subsequent method calls.
    """
    return self.add_feature(
        key,
        tf.train.Feature(
            float_list=tf.train.FloatList(value=_to_array(value))))
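A round-trip sketch of the chained builder API; the feature keys are arbitrary.

import tensorflow as tf
from official.core import tf_example_builder

builder = tf_example_builder.TfExampleBuilder()
serialized = (
    builder.add_bytes_feature('caption', 'hello')
    .add_ints_feature('label', [1, 2])
    .add_floats_feature('score', 0.9)
    .serialized_example)

# Parse back to verify the round trip.
restored = tf.train.Example.FromString(serialized)
assert list(restored.features.feature['label'].int64_list.value) == [1, 2]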
official/core/tf_example_builder_test.py  (new file, 0 → 100644)  (view file @ 32e4ca51)

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
...

"""Tests for tf_example_builder.

See `test_add_image_matrix_feature_with_fake_image` for the typical structure
of a unit test.
"""

from absl.testing import parameterized
import tensorflow as tf
from official.core import tf_example_builder


class TfExampleBuilderTest(tf.test.TestCase, parameterized.TestCase):

  def test_init_an_empty_example(self):
    example_builder = tf_example_builder.TfExampleBuilder()
    example = example_builder.example
    self.assertProtoEquals('', example)

  def test_init_an_empty_serialized_example(self):
    example_builder = tf_example_builder.TfExampleBuilder()
    example = example_builder.serialized_example
    self.assertProtoEquals('', example)

  def test_add_feature(self):
    example_builder = tf_example_builder.TfExampleBuilder()
    example_builder.add_feature(
        'foo',
        tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[b'Hello World!'])))
    example = example_builder.example
    # Use proto text to show how the entire proto would look like.
    self.assertProtoEquals(
        """
        features: {
          feature: {
            key: "foo"
            value: {
              bytes_list: {
                value: "Hello World!"
              }
            }
          }
        }""", example)

  def test_add_feature_dict(self):
    example_builder = tf_example_builder.TfExampleBuilder()
    example_builder.add_feature_dict({
        'foo':
            tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[b'Hello World!'])),
        'bar':
            tf.train.Feature(
                int64_list=tf.train.Int64List(value=[299, 792, 458]))
    })
    example = example_builder.example
    # Use proto text to show how the entire proto would look like.
    self.assertProtoEquals(
        """
        features: {
          feature: {
            key: "foo"
            value: {
              bytes_list: {
                value: "Hello World!"
              }
            }
          }
          feature: {
            key: "bar"
            value: {
              int64_list: {
                value: 299
                value: 792
                value: 458
              }
            }
          }
        }""", example)

  @parameterized.named_parameters(
      ('single_bytes', b'Hello World!', b'Hello World!'),
      ('single_string', 'Hello World!', b'Hello World!'))
  def test_add_single_byte_feature(self, value, expected_value):
    example_builder = tf_example_builder.TfExampleBuilder()
    example_builder.add_bytes_feature('foo', value)
    example = example_builder.example
    # Use the constructor to easily work with test parameters.
    self.assertProtoEquals(
        tf.train.Example(
            features=tf.train.Features(
                feature={
                    'foo':
                        tf.train.Feature(
                            bytes_list=tf.train.BytesList(
                                value=[expected_value]))
                })), example)

  @parameterized.named_parameters(
      ('multiple_bytes', [b'Hello World!', b'Good Morning!'],
       [b'Hello World!', b'Good Morning!']),
      ('multiple_string', ['Hello World!', 'Good Morning!'],
       [b'Hello World!', b'Good Morning!']))
  def test_add_multiple_bytes_feature(self, values, expected_values):
    example_builder = tf_example_builder.TfExampleBuilder()
    example_builder.add_bytes_feature('foo', values)
    example = example_builder.example
    self.assertProtoEquals(
        tf.train.Example(
            features=tf.train.Features(
                feature={
                    'foo':
                        tf.train.Feature(
                            bytes_list=tf.train.BytesList(
                                value=expected_values))
                })), example)

  @parameterized.named_parameters(
      ('single_integer', 123, [123]),
      ('multiple_integers', [123, 456, 789], [123, 456, 789]))
  def test_add_ints_feature(self, value, expected_value):
    example_builder = tf_example_builder.TfExampleBuilder()
    example_builder.add_ints_feature('bar', value)
    example = example_builder.example
    self.assertProtoEquals(
        tf.train.Example(
            features=tf.train.Features(
                feature={
                    'bar':
                        tf.train.Feature(
                            int64_list=tf.train.Int64List(
                                value=expected_value))
                })), example)

  @parameterized.named_parameters(
      ('single_float', 3.14, [3.14]),
      ('multiple_floats', [3.14, 1.57, 6.28], [3.14, 1.57, 6.28]))
  def test_add_floats_feature(self, value, expected_value):
    example_builder = tf_example_builder.TfExampleBuilder()
    example_builder.add_floats_feature('baz', value)
    example = example_builder.example
    self.assertProtoEquals(
        tf.train.Example(
            features=tf.train.Features(
                feature={
                    'baz':
                        tf.train.Feature(
                            float_list=tf.train.FloatList(
                                value=expected_value))
                })), example)


if __name__ == '__main__':
  tf.test.main()
official/core/tf_example_feature_key.py  (new file, 0 → 100644)  (view file @ 32e4ca51)

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
...

"""Data classes for tf.Example proto feature keys.

Feature keys are grouped by feature types. Key names follow conventions in
go/tf-example.
"""

import dataclasses
import functools
from typing import Optional

# Disable the init function to use the one defined in the base class.
dataclass = functools.partial(dataclasses.dataclass(init=False))


@dataclass
class TfExampleFeatureKeyBase:
  """Base dataclass for defining tf.Example proto feature keys.

  This class defines the logic of adding a prefix to feature keys. Subclasses
  define feature keys for a specific feature type in data fields.

  NOTE: Please follow the subclass examples in this module to define feature
  keys for a new feature type.
  """

  def __init__(self, prefix: Optional[str] = None):
    """Instantiates the feature key class.

    Adds a string prefix to all fields of a feature key instance if `prefix`
    is neither None nor empty.

    Example usage:

    >>> test_key = EncodedImageFeatureKey()
    >>> test_key.encoded
    image/encoded
    >>> test_key = EncodedImageFeatureKey('prefix')
    >>> test_key.encoded
    prefix/image/encoded

    Args:
      prefix: A prefix string that will be added before the feature key string
        with a trailing slash '/'.
    """
    if prefix:
      for field in dataclasses.fields(self):
        key_name = field.name
        key_value = getattr(self, key_name)
        setattr(self, key_name, f'{prefix}/{key_value}')
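A sketch of defining a concrete key class with the `dataclass` helper above, mirroring the `EncodedImageFeatureKey` mentioned in the docstring; the field values are illustrative.

from official.core import tf_example_feature_key

@tf_example_feature_key.dataclass
class EncodedImageFeatureKey(tf_example_feature_key.TfExampleFeatureKeyBase):
  encoded: str = 'image/encoded'
  label: str = 'image/class/label'

keys = EncodedImageFeatureKey('prefix')
assert keys.encoded == 'prefix/image/encoded'
assert keys.label == 'prefix/image/class/label'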
official/core/tf_example_feature_key_test.py  (new file, 0 → 100644)  (view file @ 32e4ca51)

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
...

"""Tests for tf_example_feature_key."""

import dataclasses
import inspect

from absl.testing import absltest
from absl.testing import parameterized
from official.core import tf_example_feature_key


@tf_example_feature_key.dataclass
class TestFeatureKey(tf_example_feature_key.TfExampleFeatureKeyBase):
  test: str = 'foo/bar'


class TfExampleFeatureKeyTest(parameterized.TestCase):

  def test_add_prefix_success(self):
    test_key = TestFeatureKey('prefix')
    self.assertEqual(test_key.test, 'prefix/foo/bar')

  @parameterized.parameters(None, '')
  def test_add_prefix_skip_success(self, prefix):
    test_key = TestFeatureKey(prefix)
    self.assertEqual(test_key.test, 'foo/bar')

  def test_all_feature_key_classes_are_valid(self):
    for _, obj in inspect.getmembers(tf_example_feature_key):
      if inspect.isclass(obj):
        self.assertTrue(dataclasses.is_dataclass(obj))
        self.assertTrue(
            issubclass(obj, tf_example_feature_key.TfExampleFeatureKeyBase))


if __name__ == '__main__':
  absltest.main()
official/core/train_lib.py  (view file @ 32e4ca51)

- # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+ # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
  ...
@@ -15,7 +15,7 @@
  """TFM common training driver library."""
  # pytype: disable=attribute-error
  import os
- from typing import Any, Mapping, Optional, Tuple
+ from typing import Any, Mapping, Optional, Tuple, List

  # Import libraries
  ...
@@ -32,7 +32,29 @@ from official.core import train_utils
maybe_create_best_ckpt_exporter
=
train_utils
.
maybe_create_best_ckpt_exporter
def
run_experiment
(
class
OrbitExperimentRunner
:
"""Runs experiment with Orbit training loop.
The default experiment runner for model garden experiments. User can
customize the experiment pipeline by subclassing this class and replacing
components or functions.
For example, an experiment runner with customized checkpoint manager:
```python
class MyExpRunnerWithExporter(AbstractExperimentRunner):
def _maybe_build_checkpoint_manager(sefl):
return MyCheckpointManager(*args)
# In user code
MyExpRunnerWithExporter(**needed_kwargs).run(mode)
```
Similar override can be done to other components.
"""
def
__init__
(
self
,
distribution_strategy
:
tf
.
distribute
.
Strategy
,
task
:
base_task
.
Task
,
mode
:
str
,
...
...
@@ -40,111 +62,245 @@ def run_experiment(
model_dir
:
str
,
run_post_eval
:
bool
=
False
,
save_summary
:
bool
=
True
,
train_actions
:
Optional
[
List
[
orbit
.
Action
]]
=
None
,
eval_actions
:
Optional
[
List
[
orbit
.
Action
]]
=
None
,
trainer
:
Optional
[
base_trainer
.
Trainer
]
=
None
,
controller_cls
=
orbit
.
Controller
)
->
Tuple
[
tf
.
keras
.
Model
,
Mapping
[
str
,
Any
]]
:
"""
Ru
ns
tr
ain/eval configured by the experiment params
.
)
:
"""
Co
nstr
uctor
.
Args:
distribution_strategy: A distribution
distribution_
strategy.
distribution_strategy: A distribution strategy.
task: A Task instance.
mode: A 'str', specifying the mode. Can be 'train', 'eval',
'train_and_eval'
or 'continuous_eval'.
mode: A 'str', specifying the mode. Can be 'train', 'eval',
'train_and_eval'
or 'continuous_eval'.
params: ExperimentConfig instance.
model_dir: A 'str', a path to store model checkpoints and summaries.
run_post_eval: Whether to run post eval once after training, metrics logs
are returned.
save_summary: Whether to save train and validation summary.
trainer: the base_trainer.Trainer instance. It should be created within the
strategy.scope().
train_actions: Optional list of Orbit train actions.
eval_actions: Optional list of Orbit eval actions.
trainer: the base_trainer.Trainer instance. It should be created within
the strategy.scope().
controller_cls: The controller class to manage the train and eval process.
Must be a orbit.Controller subclass.
Returns:
A 2-tuple of (model, eval_logs).
model: `tf.keras.Model` instance.
eval_logs: returns eval metrics logs when run_post_eval is set to True,
otherwise, returns {}.
"""
    self.strategy = distribution_strategy or tf.distribute.get_strategy()
    self._params = params
    self._model_dir = model_dir
    self._mode = mode
    self._run_post_eval = run_post_eval
  with distribution_strategy.scope():
    if not trainer:
      trainer = train_utils.create_trainer(
          params,
    self._trainer = trainer or self._build_trainer(
          task,
          train='train' in mode,
          evaluate=('eval' in mode) or run_post_eval,
          checkpoint_exporter=maybe_create_best_ckpt_exporter(
              params, model_dir))
        evaluate=('eval' in mode) or run_post_eval)
    assert self.trainer is not None
    self._checkpoint_manager = self._maybe_build_checkpoint_manager()
    self._controller = self._build_controller(
        trainer=self.trainer if 'train' in mode else None,
        evaluator=self.trainer,
        save_summary=save_summary,
        train_actions=train_actions,
        eval_actions=eval_actions,
        controller_cls=controller_cls)
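To make the class docstring's customization pattern concrete, here is a hedged sketch of overriding the checkpoint-manager hook. `MyCheckpointManager` is hypothetical; only the override point comes from the source:

```python
class MyExpRunnerWithExporter(OrbitExperimentRunner):
  """Example subclass that swaps in a custom checkpoint manager."""

  def _maybe_build_checkpoint_manager(self):
    # The runner only requires a tf.train.CheckpointManager-compatible
    # object, so any drop-in replacement works here (hypothetical class).
    return MyCheckpointManager(
        self.trainer.checkpoint, directory=self.model_dir, max_to_keep=5)


# In user code, instead of OrbitExperimentRunner(...).run(mode):
# MyExpRunnerWithExporter(**needed_kwargs).run(mode)
```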
  @property
  def params(self) -> config_definitions.ExperimentConfig:
    return self._params

  @property
  def model_dir(self) -> str:
    return self._model_dir

  @property
  def trainer(self) -> base_trainer.Trainer:
    return self._trainer

  @property
  def checkpoint_manager(self) -> tf.train.CheckpointManager:
    return self._checkpoint_manager

  @property
  def controller(self) -> orbit.Controller:
    return self._controller
  def _build_trainer(self, task: base_task.Task, train: bool,
                     evaluate: bool) -> base_trainer.Trainer:
    """Create trainer."""
    with self.strategy.scope():
      trainer = train_utils.create_trainer(
          self.params,
          task,
          train=train,
          evaluate=evaluate,
          checkpoint_exporter=self._build_best_checkpoint_exporter())
    return trainer
  if trainer.checkpoint:
    if model_dir is None:
  def _build_best_checkpoint_exporter(self):
    return maybe_create_best_ckpt_exporter(self.params, self.model_dir)
  def _maybe_build_checkpoint_manager(
      self) -> Optional[tf.train.CheckpointManager]:
    """Maybe create a CheckpointManager."""
    assert self.trainer is not None
    if self.trainer.checkpoint:
      if self.model_dir is None:
        raise ValueError('model_dir must be specified, but got None')
      checkpoint_manager = tf.train.CheckpointManager(
          trainer.checkpoint,
          directory=model_dir,
          max_to_keep=params.trainer.max_to_keep,
          step_counter=trainer.global_step,
          checkpoint_interval=params.trainer.checkpoint_interval,
          init_fn=trainer.initialize)
          self.trainer.checkpoint,
          directory=self.model_dir,
          max_to_keep=self.params.trainer.max_to_keep,
          step_counter=self.trainer.global_step,
          checkpoint_interval=self.params.trainer.checkpoint_interval,
          init_fn=self.trainer.initialize)
    else:
      checkpoint_manager = None
    return checkpoint_manager
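For reference, a minimal sketch of the same tf.train.CheckpointManager wiring in isolation; the model, optimizer, and directory below are placeholders, not part of train_lib:

```python
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.SGD()

checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
manager = tf.train.CheckpointManager(
    checkpoint,
    directory='/tmp/model_dir',         # placeholder path
    max_to_keep=3,                      # prune to the 3 newest checkpoints
    step_counter=optimizer.iterations,  # drives checkpoint_interval checks
    checkpoint_interval=1000)           # skip saves < 1000 steps apart

# save() returns the checkpoint path, or None when the interval check
# suppresses the save.
path = manager.save(checkpoint_number=optimizer.iterations)
```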
  def _build_controller(self,
                        trainer,
                        evaluator,
                        save_summary: bool = True,
                        train_actions: Optional[List[orbit.Action]] = None,
                        eval_actions: Optional[List[orbit.Action]] = None,
                        controller_cls=orbit.Controller) -> orbit.Controller:
    """Builds an Orbit controller."""
    train_actions = [] if not train_actions else train_actions
    if trainer:
      train_actions += actions.get_train_actions(
          self.params, trainer, self.model_dir,
          checkpoint_manager=self.checkpoint_manager)

    eval_actions = [] if not eval_actions else eval_actions
    if evaluator:
      eval_actions += actions.get_eval_actions(self.params, evaluator,
                                               self.model_dir)
    controller = controller_cls(
        strategy=distribution_strategy,
        trainer=trainer if 'train' in mode else None,
        evaluator=trainer,
        global_step=trainer.global_step,
        steps_per_loop=params.trainer.steps_per_loop,
        checkpoint_manager=checkpoint_manager,
        summary_dir=os.path.join(model_dir, 'train')
        if (save_summary) else None,
        eval_summary_dir=os.path.join(
            model_dir, params.trainer.validation_summary_subdir) if
        strategy=self.strategy,
        trainer=trainer,
        evaluator=evaluator,
        global_step=self.trainer.global_step,
        steps_per_loop=self.params.trainer.steps_per_loop,
        checkpoint_manager=self.checkpoint_manager,
        summary_dir=os.path.join(self.model_dir, 'train')
        if (save_summary) else None,
        eval_summary_dir=os.path.join(
            self.model_dir, self.params.trainer.validation_summary_subdir)
        if (save_summary) else None,
        summary_interval=params.trainer.summary_interval if
        summary_interval=self.params.trainer.summary_interval
        if (save_summary) else None,
        train_actions=actions.get_train_actions(
            params, trainer, model_dir,
            checkpoint_manager=checkpoint_manager),
        eval_actions=actions.get_eval_actions(params, trainer, model_dir))
        train_actions=train_actions,
        eval_actions=eval_actions)
    return controller
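Beyond the model-garden `actions.get_train_actions`/`actions.get_eval_actions` helpers, the `train_actions`/`eval_actions` lists accept any Orbit action, i.e. a callable invoked with the latest output dict of a train or eval loop. A minimal hedged sketch; the 'training_loss' key is an assumption about the task's output, not guaranteed by train_lib:

```python
class LogLossAction:
  """Toy Orbit action: print the loss from each train loop's output."""

  def __call__(self, output):
    loss = output.get('training_loss')  # assumed key in the output dict
    if loss is not None:
      print(f'training_loss: {float(loss):.4f}')


# Passed through run_experiment(train_actions=[LogLossAction()], ...),
# the controller invokes it after every train loop.
```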
  def run(self) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
    """Run experiments by mode.

    Returns:
      A 2-tuple of (model, eval_logs).
        model: `tf.keras.Model` instance.
        eval_logs: returns eval metrics logs when run_post_eval is set to
          True, otherwise, returns {}.
    """
    mode = self._mode
    params = self.params
    logging.info('Starts to execute mode: %s', mode)
  with distribution_strategy.scope():
    if mode == 'train':
      controller.train(steps=params.trainer.train_steps)
    with self.strategy.scope():
      if mode == 'train' or mode == 'train_and_post_eval':
        self.controller.train(steps=params.trainer.train_steps)
      elif mode == 'train_and_eval':
        controller.train_and_evaluate(
        self.controller.train_and_evaluate(
            train_steps=params.trainer.train_steps,
            eval_steps=params.trainer.validation_steps,
            eval_interval=params.trainer.validation_interval)
      elif mode == 'eval':
        controller.evaluate(steps=params.trainer.validation_steps)
        self.controller.evaluate(steps=params.trainer.validation_steps)
      elif mode == 'continuous_eval':

        def timeout_fn():
          if trainer.global_step.numpy() >= params.trainer.train_steps:
          if self.trainer.global_step.numpy() >= params.trainer.train_steps:
            return True
          return False
        controller.evaluate_continuously(
        self.controller.evaluate_continuously(
            steps=params.trainer.validation_steps,
            timeout=params.trainer.continuous_eval_timeout,
            timeout_fn=timeout_fn)
      else:
        raise NotImplementedError('The mode is not implemented: %s' % mode)
    num_params = train_utils.try_count_params(trainer.model)
    num_params = train_utils.try_count_params(self.trainer.model)
    if num_params is not None:
      logging.info('Number of trainable params in model: %f Millions.',
                   num_params / 10.**6)
    flops = train_utils.try_count_flops(trainer.model)
    flops = train_utils.try_count_flops(self.trainer.model)
    if flops is not None:
      logging.info('FLOPs (multi-adds) in model: %f Billions.',
                   flops / 10.**9 / 2)
  if run_post_eval:
    with distribution_strategy.scope():
      return trainer.model, trainer.evaluate(
          tf.convert_to_tensor(params.trainer.validation_steps))
    if self._run_post_eval or mode == 'train_and_post_eval':
      with self.strategy.scope():
        return self.trainer.model, self.controller.evaluate(
            steps=params.trainer.validation_steps)
    else:
      return trainer.model, {}
      return self.trainer.model, {}
def run_experiment(
    distribution_strategy: tf.distribute.Strategy,
    task: base_task.Task,
    mode: str,
    params: config_definitions.ExperimentConfig,
    model_dir: str,
    run_post_eval: bool = False,
    save_summary: bool = True,
    train_actions: Optional[List[orbit.Action]] = None,
    eval_actions: Optional[List[orbit.Action]] = None,
    trainer: Optional[base_trainer.Trainer] = None,
    controller_cls=orbit.Controller
) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
"""Runs train/eval configured by the experiment params.
Args:
distribution_strategy: A distribution distribution_strategy.
task: A Task instance.
mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval'
or 'continuous_eval'.
params: ExperimentConfig instance.
model_dir: A 'str', a path to store model checkpoints and summaries.
run_post_eval: Whether to run post eval once after training, metrics logs
are returned.
save_summary: Whether to save train and validation summary.
train_actions: Optional list of Orbit train actions.
eval_actions: Optional list of Orbit eval actions.
trainer: the base_trainer.Trainer instance. It should be created within the
strategy.scope().
controller_cls: The controller class to manage the train and eval process.
Must be a orbit.Controller subclass.
Returns:
A 2-tuple of (model, eval_logs).
model: `tf.keras.Model` instance.
eval_logs: returns eval metrics logs when run_post_eval is set to True,
otherwise, returns {}.
"""
  runner = OrbitExperimentRunner(
      distribution_strategy=distribution_strategy,
      task=task,
      mode=mode,
      params=params,
      model_dir=model_dir,
      run_post_eval=run_post_eval,
      save_summary=save_summary,
      train_actions=train_actions,
      eval_actions=eval_actions,
      trainer=trainer,
      controller_cls=controller_cls,
  )
  return runner.run()
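End to end, a hedged usage sketch of the wrapper above. The experiment name 'resnet_imagenet' and the model_dir path are assumptions; exp_factory and task_factory are the sibling modules in official/core:

```python
import tensorflow as tf

from official.core import exp_factory
from official.core import task_factory
from official.core import train_lib

# 'resnet_imagenet' is an assumed registered experiment name.
params = exp_factory.get_exp_config('resnet_imagenet')
task = task_factory.get_task(params.task, logging_dir='/tmp/model_dir')
model, eval_logs = train_lib.run_experiment(
    distribution_strategy=tf.distribute.get_strategy(),
    task=task,
    mode='train_and_eval',
    params=params,
    model_dir='/tmp/model_dir',
    run_post_eval=True)
print(eval_logs)
```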