Commit 32e4ca51 authored by qianyj

Update code to v2.11.0

parents 9485aa1d 71060f67
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -33,57 +33,6 @@ ExperimentConfig = config_definitions.ExperimentConfig
TrainerConfig = config_definitions.TrainerConfig
class Recovery:
"""Built-in model blowup recovery module.
Checks the loss value against the given threshold. If the threshold is
exceeded, recovers the model by reading the checkpoint on disk.
"""
def __init__(self,
loss_upper_bound: float,
checkpoint_manager: tf.train.CheckpointManager,
recovery_begin_steps: int = 0,
recovery_max_trials: int = 3):
self.recover_counter = 0
self.recovery_begin_steps = recovery_begin_steps
self.recovery_max_trials = recovery_max_trials
self.loss_upper_bound = loss_upper_bound
self.checkpoint_manager = checkpoint_manager
def should_recover(self, loss_value, global_step):
if tf.math.is_nan(loss_value):
return True
if (global_step >= self.recovery_begin_steps and
loss_value > self.loss_upper_bound):
return True
return False
def maybe_recover(self, loss_value, global_step):
"""Conditionally recovers the training by triggering checkpoint restoration.
Args:
loss_value: the loss value as a float.
global_step: the number of global training steps.
Raises:
RuntimeError: when recovery happens more than the max number of trials,
the job should crash.
"""
if not self.should_recover(loss_value, global_step):
return
self.recover_counter += 1
if self.recover_counter > self.recovery_max_trials:
raise RuntimeError(
"The loss value is NaN or out of range after training loop and "
f"this happens {self.recover_counter} times.")
# Loads the previous good checkpoint.
checkpoint_path = self.checkpoint_manager.restore_or_initialize()
logging.warning(
"Recovering the model from checkpoint: %s. The loss value becomes "
"%f at step %d.", checkpoint_path, loss_value, global_step)
class _AsyncTrainer(orbit.StandardTrainer, orbit.StandardEvaluator):
"""Trainer class for both sync and async Strategy."""
@@ -370,6 +319,11 @@ class Trainer(_AsyncTrainer):
"""Accesses the training checkpoint."""
return self._checkpoint
@property
def checkpoint_exporter(self):
"""Accesses the checkpoint exporter."""
return self._checkpoint_exporter
def train_loop_end(self):
"""See base class."""
self.join()
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -150,30 +150,6 @@ class MockAsyncTrainer(trainer_lib._AsyncTrainer):
return self.eval_global_step.numpy()
class RecoveryTest(tf.test.TestCase):
def test_recovery_module(self):
ckpt = tf.train.Checkpoint(v=tf.Variable(1, dtype=tf.int32))
model_dir = self.get_temp_dir()
manager = tf.train.CheckpointManager(ckpt, model_dir, max_to_keep=1)
recovery_module = trainer_lib.Recovery(
loss_upper_bound=1.0,
checkpoint_manager=manager,
recovery_begin_steps=1,
recovery_max_trials=1)
self.assertFalse(recovery_module.should_recover(1.1, 0))
self.assertFalse(recovery_module.should_recover(0.1, 1))
self.assertTrue(recovery_module.should_recover(1.1, 2))
# First triggers the recovery once.
recovery_module.maybe_recover(1.1, 10)
# Second time, it raises.
with self.assertRaisesRegex(
RuntimeError, 'The loss value is NaN .*'):
recovery_module.maybe_recover(1.1, 10)
class TrainerTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
@@ -343,7 +319,9 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
self.assertFalse(trainer.optimizer.dynamic)
self.assertEqual(trainer.optimizer.initial_scale, loss_scale)
else:
self.assertIsInstance(trainer.optimizer, tf.keras.optimizers.SGD)
self.assertIsInstance(
trainer.optimizer,
(tf.keras.optimizers.SGD, tf.keras.optimizers.legacy.SGD))
metrics = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
self.assertIn('training_loss', metrics)
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,6 +19,7 @@ from typing import Optional, Sequence, Union
from official.modeling.hyperparams import base_config
from official.modeling.optimization.configs import optimization_config
from official.modeling.privacy import configs as dp_configs
OptimizationConfig = optimization_config.OptimizationConfig
@@ -74,7 +75,35 @@ class DataConfig(base_config.Config):
decoding when loading dataset from TFDS. Use comma to separate multiple
features. The main use case is to skip the image/video decoding for better
performance.
enable_shared_tf_data_service_between_parallel_trainers: A bool. When set to
true, only a single tf.data service will be started, and it will be shared
between all the trainers running simultaneously, e.g. when using vizier to
tune hyperparameters. This saves CPU and RAM resources compared to running a
separate tf.data service for each trainer. Note that if the batch size
differs between trainers, the field apply_tf_data_service_before_batching
also needs to be true so that only a single tf.data service instance is
created; in this case, tf.data service is applied before the batching
operation. Make sure not to apply any processing steps after batching (e.g.
in postprocess_fn), since they would not be parallelized by tf.data service
and may slow down your tf.data pipeline. When using a shared tf.data
service, the tf.data dataset must be infinite, and slow trainers may skip
certain training examples. More details about shared tf.data service can be
found at:
https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#sharing_tfdata_service_with_concurrent_trainers.
apply_tf_data_service_before_batching: A bool. If set to True, tf.data
service will be applied before the batching operation. This is useful to
make sure only a single tf.data service instance is created when
enable_shared_tf_data_service_between_parallel_trainers is true and the
batch size differs between parallel trainers.
trainer_id: A string. The id of the trainer if multiple parallel trainers
are running at the same time, e.g. in the vizier tuning case. It will be set
automatically if needed. Users do not need to set it when creating
experiment configs.
seed: An optional seed to use for deterministic shuffling/preprocessing.
prefetch_buffer_size: An int specifying the buffer size of prefetch
datasets. If None, the buffer size is autotuned. Specifying this is useful
in case autotuning uses up too much memory by making the buffer size too
high.
"""
input_path: Union[Sequence[str], str, base_config.Config] = ""
tfds_name: str = ""
@@ -94,7 +123,11 @@ class DataConfig(base_config.Config):
tfds_data_dir: str = ""
tfds_as_supervised: bool = False
tfds_skip_decoding_feature: str = ""
enable_shared_tf_data_service_between_parallel_trainers: bool = False
apply_tf_data_service_before_batching: bool = False
trainer_id: Optional[str] = None
seed: Optional[int] = None
prefetch_buffer_size: Optional[int] = None
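To make the new shared tf.data service fields concrete, a hedged configuration sketch follows; `enable_tf_data_service` and `tf_data_service_address` are existing fields referenced by the InputReader code further down, and all values are illustrative:

```python
# Hedged sketch: two parallel trainers (e.g. vizier trials with different
# batch sizes) sharing a single tf.data service instance.
from official.core import config_definitions as cfg

data_config = cfg.DataConfig(
    input_path='/path/to/train*',  # placeholder
    global_batch_size=128,
    enable_tf_data_service=True,
    tf_data_service_address='grpc://tf-data-service:5050',  # placeholder
    enable_shared_tf_data_service_between_parallel_trainers=True,
    # Required when batch size differs between trainers, so only a single
    # tf.data service instance is created (before the batching step).
    apply_tf_data_service_before_batching=True,
    seed=42,
    prefetch_buffer_size=None,  # None lets tf.data autotune the buffer size
)
```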
@dataclasses.dataclass
@@ -189,8 +222,8 @@ class TrainerConfig(base_config.Config):
is only used in continuous_train_and_eval and continuous_eval modes. Default
value is 1 hr.
train_steps: number of train steps.
validation_steps: number of eval steps. If `None`, the entire eval dataset
is used.
validation_steps: number of eval steps. If -1, the entire eval dataset is
used.
validation_interval: number of training steps to run between evaluations.
best_checkpoint_export_subdir: if set, the trainer will keep track of the
best evaluation metric, and export the corresponding best checkpoint under
@@ -240,11 +273,17 @@ class TrainerConfig(base_config.Config):
@dataclasses.dataclass
class TaskConfig(base_config.Config):
"""Config passed to task."""
init_checkpoint: str = ""
model: Optional[base_config.Config] = None
train_data: DataConfig = DataConfig()
validation_data: DataConfig = DataConfig()
name: Optional[str] = None
# Configs for differential privacy
# These configs are only effective if you use create_optimizer in
# tensorflow_models/official/core/base_task.py
differential_privacy_config: Optional[
dp_configs.DifferentialPrivacyConfig] = None
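A short sketch of the new differential-privacy hook on `TaskConfig`; default-constructing `DifferentialPrivacyConfig` is an assumption about its schema, and per the comment above it only takes effect when the task builds its optimizer via `create_optimizer` in official/core/base_task.py:

```python
# Hedged sketch: enabling differential privacy on a task config.
from official.core import config_definitions as cfg
from official.modeling.privacy import configs as dp_configs

task_config = cfg.TaskConfig(
    train_data=cfg.DataConfig(input_path='/path/to/train*'),  # placeholder
    differential_privacy_config=dp_configs.DifferentialPrivacyConfig(),
)
```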
@dataclasses.dataclass
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -67,6 +67,15 @@ class ExportModule(tf.Module, metaclass=abc.ABCMeta):
if inference_step is not None:
self.inference_step = functools.partial(inference_step, model=self.model)
else:
if issubclass(type(model), tf.keras.Model):
# Default to self.model.call instead of self.model.__call__ to avoid
# the keras tracing logic designed for training.
# Since most Model Garden models' call() methods don't have a training
# kwarg, or default it to False, we don't pass anything here.
# Please pass a custom inference step if your model defaults to
# training=True.
self.inference_step = self.model.call
else:
self.inference_step = functools.partial(
self.model.__call__, training=False)
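Per the comment above, models whose `call()` defaults to `training=True` need a custom inference step; a minimal sketch (the concrete `ExportModule` subclass and its other constructor arguments are assumptions):

```python
# Hedged sketch: a custom inference step that forces inference-mode behavior
# (dropout off, frozen batch-norm statistics). `MyExportModule` is a
# hypothetical ExportModule subclass.
import tensorflow as tf

def my_inference_step(inputs, model: tf.keras.Model):
  return model(inputs, training=False)

# export_module = MyExportModule(params=params, model=model,
#                                inference_step=my_inference_step)
```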
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""File writer functions for dataset preparation, infra validation, and unit tests."""
import io
from typing import Optional, Sequence, Union
import tensorflow as tf
def write_small_dataset(examples: Sequence[Union[tf.train.Example,
tf.train.SequenceExample]],
output_path: str,
file_type: str = 'tfrecord') -> None:
"""Writes `examples` to a file at `output_path` with type `file_type`.
CAVEAT: This function is not recommended for writing large datasets, since it
will loop through `examples` and perform write operation sequentially.
Args:
examples: List of tf.train.Example or tf.train.SequenceExample.
output_path: Output path for the dataset.
file_type: A string indicating the file format, could be: 'tfrecord',
'tfrecords', 'tfrecord_compressed', 'tfrecords_gzip', 'riegeli'. The
string is case insensitive.
"""
file_type = file_type.lower()
if file_type == 'tfrecord' or file_type == 'tfrecords':
_write_tfrecord(examples, output_path)
elif file_type == 'tfrecord_compressed' or file_type == 'tfrecords_gzip':
_write_tfrecord(examples, output_path,
tf.io.TFRecordOptions(compression_type='GZIP'))
elif file_type == 'riegeli':
_write_riegeli(examples, output_path)
else:
raise ValueError(f'Unknown file_type: {file_type}')
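Example usage of `write_small_dataset`, assuming a handful of in-memory examples; the output path is illustrative:

```python
# Usage sketch: write two tiny tf.train.Examples to a GZIP-compressed
# TFRecord file.
import tensorflow as tf

from official.core import file_writers

examples = [
    tf.train.Example(features=tf.train.Features(feature={
        'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[i])),
    }))
    for i in range(2)
]
file_writers.write_small_dataset(examples, '/tmp/train.tfrecord.gz',
                                 'tfrecords_gzip')
```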
def _write_tfrecord(examples: Sequence[Union[tf.train.Example,
tf.train.SequenceExample]],
output_path: str,
options: Optional[tf.io.TFRecordOptions] = None) -> None:
"""Writes `examples` to a TFRecord file at `output_path`.
Args:
examples: A list of tf.train.Example.
output_path: Output path for the dataset.
options: Options used for manipulating TFRecord files.
"""
with tf.io.TFRecordWriter(output_path, options) as writer:
for example in examples:
writer.write(example.SerializeToString())
def _write_riegeli(examples: Sequence[Union[tf.train.Example,
tf.train.SequenceExample]],
output_path: str) -> None:
"""Writes `examples` to a Riegeli file at `output_path`.
Args:
examples: A list of tf.train.Example.
output_path: Output path for the dataset.
"""
with io.FileIO(output_path, 'wb') as fileio:
import riegeli # pylint: disable=g-import-not-at-top
with riegeli.RecordWriter(fileio) as writer:
writer.write_messages(examples)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for file_writers."""
import os
from absl.testing import parameterized
import tensorflow as tf
from official.core import file_writers
from official.core import tf_example_builder
class FileWritersTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super().setUp()
example_builder = tf_example_builder.TfExampleBuilder()
example_builder.add_bytes_feature('foo', 'Hello World!')
self._example = example_builder.example
@parameterized.parameters('tfrecord', 'TFRecord', 'tfrecords',
'tfrecord_compressed', 'TFRecord_Compressed',
'tfrecords_gzip')
def test_write_small_dataset_success(self, file_type):
temp_dir = self.create_tempdir()
temp_dataset_file = os.path.join(temp_dir.full_path, 'train')
file_writers.write_small_dataset([self._example], temp_dataset_file,
file_type)
self.assertTrue(os.path.exists(temp_dataset_file))
def test_write_small_dataset_unrecognized_format(self):
file_type = 'bar'
temp_dir = self.create_tempdir()
temp_dataset_file = os.path.join(temp_dir.full_path, 'train')
with self.assertRaises(ValueError):
file_writers.write_small_dataset([self._example], temp_dataset_file,
file_type)
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -160,16 +160,38 @@ def _read_tfds(tfds_builder: tfds.core.DatasetBuilder,
"""Reads a dataset from tfds."""
# No-op if the data has already been downloaded and prepared.
tfds_builder.download_and_prepare()
decoders = {}
if tfds_skip_decoding_feature:
for skip_feature in tfds_skip_decoding_feature.split(','):
decoders[skip_feature.strip()] = tfds.decode.SkipDecoding()
if tfds_builder.info.splits:
num_shards = len(tfds_builder.info.splits[tfds_split].file_instructions)
else:
# The tfds mock path often does not provide splits.
num_shards = 1
if input_context and num_shards < input_context.num_input_pipelines:
# The number of files in the dataset split is smaller than the number of
# input pipelines. We read the entire dataset first and then shard it in
# host memory.
read_config = tfds.ReadConfig(
interleave_cycle_length=cycle_length,
interleave_block_length=block_length,
input_context=None,
shuffle_seed=seed)
dataset = tfds_builder.as_dataset(
split=tfds_split,
shuffle_files=is_training,
as_supervised=tfds_as_supervised,
decoders=decoders,
read_config=read_config)
dataset = dataset.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id)
else:
read_config = tfds.ReadConfig(
interleave_cycle_length=cycle_length,
interleave_block_length=block_length,
input_context=input_context,
shuffle_seed=seed)
decoders = {}
if tfds_skip_decoding_feature:
for skip_feature in tfds_skip_decoding_feature.split(','):
decoders[skip_feature.strip()] = tfds.decode.SkipDecoding()
dataset = tfds_builder.as_dataset(
split=tfds_split,
shuffle_files=is_training,
@@ -270,6 +292,8 @@ class InputReader:
self._transform_and_batch_fn = transform_and_batch_fn
self._postprocess_fn = postprocess_fn
self._seed = params.seed
self._prefetch_buffer_size = (
params.prefetch_buffer_size or tf.data.experimental.AUTOTUNE)
# When tf.data service is enabled, each data service worker should get
# different random seeds. Thus, we set `seed` to None.
@@ -282,13 +306,36 @@ class InputReader:
self._enable_tf_data_service = (
params.enable_tf_data_service and params.tf_data_service_address)
self._tf_data_service_address = params.tf_data_service_address
self._enable_shared_tf_data_service_between_parallel_trainers = (
params.enable_shared_tf_data_service_between_parallel_trainers)
self._apply_tf_data_service_before_batching = (
params.apply_tf_data_service_before_batching)
self._trainer_id = params.trainer_id
if self._enable_tf_data_service:
# Add a random seed as the tf.data service job name suffix, so tf.data
# service doesn't reuse the previous state if the TPU worker gets
# preempted.
# It's necessary to add the global batch size into the tf.data service
# job name because when tuning the batch size with vizier while tf.data
# service is also enabled, the job name should differ across vizier
# trials: once the batch size changes, the dataset is a different instance
# from the tf.data perspective, and a different tf.data service job name
# should be used. Otherwise, the model would read tensors from the
# incorrect tf.data service job, which would cause a mismatch on the batch
# size dimension.
self._tf_data_service_job_name = (
params.tf_data_service_job_name + str(self.static_randnum))
f'{params.tf_data_service_job_name}_bs{params.global_batch_size}_'
f'{self.static_randnum}')
self._enable_round_robin_tf_data_service = params.get(
'enable_round_robin_tf_data_service', False)
if self._enable_shared_tf_data_service_between_parallel_trainers:
# When shared tf.data service is enabled, only a single tf.data service
# instance should be created and shared between parallel trainers. If
# the global batch size is different across trainers,
# params.apply_tf_data_service_before_batching should be set to true
# because tf.data service with different batch sizes will be considered
# separate tf.data service instances.
self._tf_data_service_job_name = (
f'{params.tf_data_service_job_name}_{self.static_randnum}')
@property
def tfds_info(self) -> tfds.core.DatasetInfo:
@@ -411,6 +458,19 @@ class InputReader:
dataset = dataset.repeat()
dataset = dataset.shuffle(self._shuffle_buffer_size, seed=self._seed)
# Applies tf.data service before batching operations. This is useful when
# tf.data service is shared between parallel trainers and the batch size
# differs between them. When the batch size is changing, tf.data services
# applied after the batching operations will be considered different
# instances, which makes them difficult to share between parallel trainers.
# However, if there are additional expensive operations in
# self._transform_and_batch_fn and self._postprocess_fn, the entire tf.data
# pipeline could be slowed down. In this case, try to move these dataset
# operations into earlier stages if possible.
if (self._enable_shared_tf_data_service_between_parallel_trainers and
self._apply_tf_data_service_before_batching):
dataset = self._maybe_apply_data_service(dataset, input_context)
if self._transform_and_batch_fn is not None:
dataset = self._transform_and_batch_fn(dataset, input_context)
else:
@@ -436,13 +496,18 @@ class InputReader:
num_consumers = input_context.num_input_pipelines * (
replicas_per_input_pipeline)
range_dataset = tf.data.Dataset.range(replicas_per_input_pipeline)
tfds_kwargs = {
'processing_mode': 'parallel_epochs',
'service': self._tf_data_service_address,
'job_name': self._tf_data_service_job_name,
'num_consumers': num_consumers
}
if self._enable_shared_tf_data_service_between_parallel_trainers:
raise ValueError('Shared tf.data service does not support round-robin'
' tf.data service.')
dataset = range_dataset.map(lambda i: dataset.apply( # pylint: disable=g-long-lambda
tf.data.experimental.service.distribute(
processing_mode='parallel_epochs',
service=self._tf_data_service_address,
job_name=self._tf_data_service_job_name,
consumer_index=base_consumer_index + i,
num_consumers=num_consumers)))
consumer_index=base_consumer_index + i, **tfds_kwargs)))
# Use parallel interleave to read multiple batches from a tf.data
# service worker in parallel.
dataset = dataset.interleave(
@@ -451,11 +516,21 @@ class InputReader:
num_parallel_calls=replicas_per_input_pipeline,
deterministic=True)
else:
tfds_kwargs = {
'processing_mode': 'parallel_epochs',
'service': self._tf_data_service_address,
'job_name': self._tf_data_service_job_name,
}
if self._enable_shared_tf_data_service_between_parallel_trainers:
tfds_kwargs.update({
'processing_mode':
tf.data.experimental.service.ShardingPolicy.OFF,
'cross_trainer_cache':
tf.data.experimental.service.CrossTrainerCache(
trainer_id=self._trainer_id)
})
dataset = dataset.apply(
tf.data.experimental.service.distribute(
processing_mode='parallel_epochs',
service=self._tf_data_service_address,
job_name=self._tf_data_service_job_name))
tf.data.experimental.service.distribute(**tfds_kwargs))
return dataset
def read(self,
@@ -463,16 +538,17 @@ class InputReader:
dataset: Optional[tf.data.Dataset] = None) -> tf.data.Dataset:
"""Generates a tf.data.Dataset object."""
if dataset is None:
dataset = self._read_data_source(
self._matched_files, self._dataset_fn, input_context,
self._tfds_builder)
dataset = self._read_data_source(self._matched_files, self._dataset_fn,
input_context, self._tfds_builder)
dataset = self._decode_and_parse_dataset(dataset, self._global_batch_size,
input_context)
dataset = _maybe_map_fn(dataset, self._postprocess_fn)
if not (self._enable_shared_tf_data_service_between_parallel_trainers and
self._apply_tf_data_service_before_batching):
dataset = self._maybe_apply_data_service(dataset, input_context)
if self._deterministic is not None:
options = tf.data.Options()
options.experimental_deterministic = self._deterministic
options.deterministic = self._deterministic
dataset = dataset.with_options(options)
return dataset.prefetch(tf.data.experimental.AUTOTUNE)
return dataset.prefetch(self._prefetch_buffer_size)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,6 +13,7 @@
# limitations under the License.
"""Registry utility."""
from absl import logging
def register(registered_collection, reg_key):
@@ -54,6 +55,14 @@ def register(registered_collection, reg_key):
leaf_reg_key = reg_key
if leaf_reg_key in collection:
if "beta" in fn_or_cls.__module__:
# TODO(yeqing): Clean this temporary branch for beta.
logging.warning(
"Duplicate registration of beta module "
"name %r new %r old %r", reg_key, collection[leaf_reg_key],
fn_or_cls.__module__)
return fn_or_cls
else:
raise KeyError("Function or class {} registered multiple times.".format(
leaf_reg_key))
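From the fragment above, `register` reads as a decorator factory over a registry mapping; a hedged usage sketch (the registry dict and key below are hypothetical):

```python
# Hedged usage sketch, inferred from the register() fragment above.
_TASK_REGISTRY = {}  # hypothetical registry collection

@register(_TASK_REGISTRY, 'my_task')
class MyTask:
  ...

# Registering 'my_task' a second time raises KeyError, unless the duplicate
# comes from a module whose path contains "beta"; that case only logs a
# warning and keeps the earlier registration.
```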
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Custom checkpoint manager that also exports saved models."""
import os
import re
import time
from typing import Callable, List, Mapping, Optional, Union
from absl import logging
import tensorflow as tf
SAVED_MODULES_PATH_SUFFIX = 'saved_modules'
def make_saved_modules_directory_name(checkpoint_name: str) -> str:
return f'{checkpoint_name}_{SAVED_MODULES_PATH_SUFFIX}'
class SavedModelCheckpointManager(tf.train.CheckpointManager):
"""A CheckpointManager that also exports `SavedModel`s."""
def __init__(self,
checkpoint: tf.train.Checkpoint,
directory: str,
max_to_keep: int,
modules_to_export: Optional[Mapping[str, tf.Module]] = None,
keep_checkpoint_every_n_hours: Optional[int] = None,
checkpoint_name: str = 'ckpt',
step_counter: Optional[tf.Variable] = None,
checkpoint_interval: Optional[int] = None,
init_fn: Optional[Callable[[], None]] = None):
"""See base class."""
super().__init__(
checkpoint=checkpoint,
directory=directory,
max_to_keep=max_to_keep,
keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
checkpoint_name=checkpoint_name,
step_counter=step_counter,
checkpoint_interval=checkpoint_interval,
init_fn=init_fn)
self._modules_to_export = modules_to_export
self._savedmodels = self.get_existing_savedmodels()
def save(self,
checkpoint_number: Optional[int] = None,
check_interval: bool = True,
options: Optional[tf.train.CheckpointOptions] = None):
"""See base class."""
checkpoint_path = super().save(
checkpoint_number=checkpoint_number,
check_interval=check_interval,
options=options)
if not checkpoint_path: # Nothing got written.
return
if not self._modules_to_export: # No modules to export.
logging.info('Skip saving SavedModel due to empty modules_to_export.')
return checkpoint_path
# Save the models for the checkpoint that just got written.
saved_modules_directory = make_saved_modules_directory_name(checkpoint_path)
for model_name, model in self._modules_to_export.items():
signatures = getattr(model, 'saved_model_signatures', None)
tf.saved_model.save(
obj=model,
export_dir=os.path.join(saved_modules_directory, model_name),
signatures=signatures)
saved_modules_directories_to_keep = [
make_saved_modules_directory_name(ckpt) for ckpt in self.checkpoints
]
existing_saved_modules_dirs = self.get_existing_savedmodels()
self._savedmodels = []
# Keep savedmodels in the same order as checkpoints (from oldest to newest).
for saved_modules_dir_to_keep in saved_modules_directories_to_keep:
if saved_modules_dir_to_keep in existing_saved_modules_dirs:
self._savedmodels.append(saved_modules_dir_to_keep)
for existing_saved_modules_dir in existing_saved_modules_dirs:
if existing_saved_modules_dir not in self._savedmodels:
tf.io.gfile.rmtree(existing_saved_modules_dir)
return checkpoint_path
def get_existing_savedmodels(self) -> List[str]:
"""Gets a list of all existing SavedModel paths in `directory`.
Returns:
A list of all existing SavedModel paths.
"""
saved_modules_glob = make_saved_modules_directory_name(
self._checkpoint_prefix + '-*')
return tf.io.gfile.glob(saved_modules_glob)
@property
def latest_savedmodel(self) -> Union[str, None]:
"""The path of the most recent SavedModel in `directory`.
Returns:
The latest SavedModel path. If there are no SavedModels, returns `None`.
"""
if self._savedmodels:
return self._savedmodels[-1]
return None
@property
def savedmodels(self) -> List[str]:
"""A list of managed SavedModels.
Returns:
A list of SavedModel paths, sorted from oldest to newest.
"""
return self._savedmodels
@property
def modules_to_export(self) -> Union[Mapping[str, tf.Module], None]:
return self._modules_to_export
def get_savedmodel_number_from_path(self,
savedmodel_path: str) -> Union[int, None]:
"""Gets the savedmodel_number/checkpoint_number from savedmodel filepath.
The savedmodel_number is the global step when used with the orbit controller.
Args:
savedmodel_path: savedmodel directory path.
Returns:
Savedmodel number or None if no matched pattern found in savedmodel path.
"""
pattern = rf'\d+_{SAVED_MODULES_PATH_SUFFIX}$'
savedmodel_number = re.search(pattern, savedmodel_path)
if savedmodel_number:
savedmodel_number = savedmodel_number.group()
return int(savedmodel_number[:-len(SAVED_MODULES_PATH_SUFFIX) - 1])
return None
def savedmodels_iterator(self,
min_interval_secs: float = 0,
timeout: Optional[float] = None,
timeout_fn: Optional[Callable[[], bool]] = None):
"""Continuously yield new SavedModel files as they appear.
The iterator only checks for new savedmodels when control flow has been
reverted to it. The logic is the same as in `tf.train.checkpoints_iterator`.
Args:
min_interval_secs: The minimum number of seconds between yielding
savedmodels.
timeout: The maximum number of seconds to wait between savedmodels. If
left as `None`, then the process will wait indefinitely.
timeout_fn: Optional function to call after a timeout. If the function
returns True, then it means that no new savedmodels will be generated
and the iterator will exit. The function is called with no arguments.
Yields:
String paths to latest SavedModel files as they arrive.
"""
savedmodel_path = None
while True:
new_savedmodel_path = self.wait_for_new_savedmodel(
savedmodel_path, timeout=timeout)
if new_savedmodel_path is None:
if not timeout_fn:
# timed out
logging.info('Timed-out waiting for a savedmodel.')
return
if timeout_fn():
# The timeout_fn indicated that we are truly done.
return
else:
# The timeout_fn indicated that more savedmodels may come.
continue
start = time.time()
savedmodel_path = new_savedmodel_path
yield savedmodel_path
time_to_next_eval = start + min_interval_secs - time.time()
if time_to_next_eval > 0:
time.sleep(time_to_next_eval)
def wait_for_new_savedmodel(
self,
last_savedmodel: Optional[str] = None,
seconds_to_sleep: float = 1.0,
timeout: Optional[float] = None) -> Union[str, None]:
"""Waits until a new savedmodel file is found.
Args:
last_savedmodel: The last savedmodel path used or `None` if we're
expecting a savedmodel for the first time.
seconds_to_sleep: The number of seconds to sleep for before looking for a
new savedmodel.
timeout: The maximum number of seconds to wait. If left as `None`, then
the process will wait indefinitely.
Returns:
A new savedmodel path, or None if the timeout was reached.
"""
logging.info('Waiting for new savedmodel at %s', self._directory)
stop_time = time.time() + timeout if timeout is not None else None
last_savedmodel_number = 0
if last_savedmodel:
last_savedmodel_number = self.get_savedmodel_number_from_path(
last_savedmodel)
while True:
if stop_time is not None and time.time() + seconds_to_sleep > stop_time:
return None
existing_savedmodels = {}
for savedmodel_path in self.get_existing_savedmodels():
savedmodel_number = self.get_savedmodel_number_from_path(
savedmodel_path)
if savedmodel_number is not None:
existing_savedmodels[savedmodel_number] = savedmodel_path
# Find the first savedmodel with a larger step number as the next one.
savedmodel_path = None
existing_savedmodels = dict(sorted(existing_savedmodels.items()))
for savedmodel_number in existing_savedmodels:
if savedmodel_number > last_savedmodel_number:
savedmodel_path = existing_savedmodels[savedmodel_number]
break
if savedmodel_path:
logging.info('Found new savedmodel at %s', savedmodel_path)
return savedmodel_path
else:
time.sleep(seconds_to_sleep)
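A minimal usage sketch of the manager, assuming a small Keras model; the directory and module name are placeholders (the test file below exercises the same flow):

```python
# Hedged usage sketch: each save() writes ckpt-N plus a sibling
# ckpt-N_saved_modules/<module_name>/ SavedModel directory.
import tensorflow as tf

from official.core import savedmodel_checkpoint_manager

model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])
manager = savedmodel_checkpoint_manager.SavedModelCheckpointManager(
    checkpoint=tf.train.Checkpoint(model=model),
    directory='/tmp/ckpt_dir',  # placeholder
    max_to_keep=3,
    modules_to_export={'serving': model})
manager.save(checkpoint_number=1)
print(manager.latest_savedmodel)  # e.g. /tmp/ckpt_dir/ckpt-1_saved_modules
```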
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
from typing import Iterable
import tensorflow as tf
from official.core import savedmodel_checkpoint_manager
def _models_exist(checkpoint_path: str, models: Iterable[str]) -> bool:
for model_name in models:
if not tf.io.gfile.isdir(
os.path.join(
savedmodel_checkpoint_manager.make_saved_modules_directory_name(
checkpoint_path), model_name)):
return False
return True
class CheckpointManagerTest(tf.test.TestCase):
def _create_manager(self, max_to_keep: int = 1) -> tf.train.CheckpointManager:
"""Sets up SavedModelCheckpointManager object.
Args:
max_to_keep: max number of savedmodels to keep.
Returns:
created savedmodel manager.
"""
models = {
'model_1':
tf.keras.Sequential(
layers=[tf.keras.layers.Dense(8, input_shape=(16,))]),
'model_2':
tf.keras.Sequential(
layers=[tf.keras.layers.Dense(16, input_shape=(32,))]),
}
checkpoint = tf.train.Checkpoint()
manager = savedmodel_checkpoint_manager.SavedModelCheckpointManager(
checkpoint=checkpoint,
directory=self.get_temp_dir(),
max_to_keep=max_to_keep,
modules_to_export=models)
return manager
def test_max_to_keep(self):
manager = self._create_manager()
models = manager.modules_to_export
first_path = manager.save()
second_path = manager.save()
savedmodel = savedmodel_checkpoint_manager.make_saved_modules_directory_name(
manager.latest_checkpoint)
self.assertEqual(savedmodel, manager.latest_savedmodel)
self.assertTrue(_models_exist(second_path, models.keys()))
self.assertFalse(_models_exist(first_path, models.keys()))
def test_returns_none_after_timeout(self):
manager = self._create_manager()
start = time.time()
ret = manager.wait_for_new_savedmodel(
None, timeout=1.0, seconds_to_sleep=0.5)
end = time.time()
self.assertIsNone(ret)
# We've waited 0.5 second.
self.assertGreater(end, start + 0.5)
# The timeout kicked in.
self.assertLess(end, start + 0.6)
def test_saved_model_iterator(self):
manager = self._create_manager(max_to_keep=2)
self.assertIsNotNone(manager.save(checkpoint_number=1))
self.assertIsNotNone(manager.save(checkpoint_number=2))
self.assertIsNotNone(manager.save(checkpoint_number=3))
# Savedmodels are in time order.
expected_savedmodels = manager.savedmodels
# Order not guaranteed.
existing_savedmodels = manager.get_existing_savedmodels()
savedmodels = list(manager.savedmodels_iterator(timeout=3.0))
self.assertEqual(savedmodels, expected_savedmodels)
self.assertEqual(set(savedmodels), set(existing_savedmodels))
def test_saved_model_iterator_timeout_fn(self):
manager = self._create_manager()
timeout_fn_calls = [0]
def timeout_fn():
timeout_fn_calls[0] += 1
return timeout_fn_calls[0] > 3
results = list(
manager.savedmodels_iterator(timeout=0.1, timeout_fn=timeout_fn))
self.assertEqual([], results)
self.assertEqual(4, timeout_fn_calls[0])
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Builder class for preparing tf.train.Example."""
# https://www.python.org/dev/peps/pep-0563/#enabling-the-future-behavior-in-python-3-7
from __future__ import annotations
from typing import Mapping, Sequence, Union
import numpy as np
import tensorflow as tf
BytesValueType = Union[bytes, Sequence[bytes], str, Sequence[str]]
_to_array = lambda v: [v] if not isinstance(v, (list, np.ndarray)) else v
_to_bytes = lambda v: v.encode() if isinstance(v, str) else v
_to_bytes_array = lambda v: list(map(_to_bytes, _to_array(v)))
class TfExampleBuilder(object):
"""Builder class for preparing tf.train.Example.
Read API doc at https://www.tensorflow.org/api_docs/python/tf/train/Example.
Example usage:
>>> example_builder = TfExampleBuilder()
>>> example = (
example_builder.add_bytes_feature('feature_a', 'foobarbaz')
.add_ints_feature('feature_b', [1, 2, 3])
.example)
"""
def __init__(self) -> None:
self._example = tf.train.Example()
@property
def example(self) -> tf.train.Example:
"""Returns a copy of the generated tf.train.Example proto."""
return self._example
@property
def serialized_example(self) -> bytes:
"""Returns the serialized bytes of the generated tf.train.Example proto."""
return self._example.SerializeToString()
def set(self, example: tf.train.Example) -> TfExampleBuilder:
"""Sets the example."""
self._example = example
return self
def reset(self) -> TfExampleBuilder:
"""Resets the example to an empty proto."""
self._example = tf.train.Example()
return self
###### Basic APIs for primitive data types ######
def add_feature_dict(
self, feature_dict: Mapping[str, tf.train.Feature]) -> TfExampleBuilder:
"""Adds the predefined `feature_dict` to the example.
Note: Please prefer using the feature-type-specific methods.
Args:
feature_dict: A dictionary from tf.Example feature key to
tf.train.Feature.
Returns:
The builder object for subsequent method calls.
"""
for k, v in feature_dict.items():
self._example.features.feature[k].CopyFrom(v)
return self
def add_feature(self, key: str,
feature: tf.train.Feature) -> TfExampleBuilder:
"""Adds predefined `feature` with `key` to the example.
Args:
key: String key of the feature.
feature: The feature to be added to the example.
Returns:
The builder object for subsequent method calls.
"""
self._example.features.feature[key].CopyFrom(feature)
return self
def add_bytes_feature(self, key: str,
value: BytesValueType) -> TfExampleBuilder:
"""Adds byte(s) or string(s) with `key` to the example.
Args:
key: String key of the feature.
value: The byte(s) or string(s) to be added to the example.
Returns:
The builder object for subsequent method calls.
"""
return self.add_feature(
key,
tf.train.Feature(
bytes_list=tf.train.BytesList(value=_to_bytes_array(value))))
def add_ints_feature(self, key: str,
value: Union[int, Sequence[int]]) -> TfExampleBuilder:
"""Adds integer(s) with `key` to the example.
Args:
key: String key of the feature.
value: The integer(s) to be added to the example.
Returns:
The builder object for subsequent method calls.
"""
return self.add_feature(
key,
tf.train.Feature(int64_list=tf.train.Int64List(value=_to_array(value))))
def add_floats_feature(
self, key: str, value: Union[float, Sequence[float]]) -> TfExampleBuilder:
"""Adds float(s) with `key` to the example.
Args:
key: String key of the feature.
value: The float(s) to be added to the example.
Returns:
The builder object for subsequent method calls.
"""
return self.add_feature(
key,
tf.train.Feature(float_list=tf.train.FloatList(value=_to_array(value))))
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tf_example_builder.
See `test_add_image_matrix_feature_with_fake_image` for the typical structure of
a unit test.
"""
from absl.testing import parameterized
import tensorflow as tf
from official.core import tf_example_builder
class TfExampleBuilderTest(tf.test.TestCase, parameterized.TestCase):
def test_init_an_empty_example(self):
example_builder = tf_example_builder.TfExampleBuilder()
example = example_builder.example
self.assertProtoEquals('', example)
def test_init_an_empty_serialized_example(self):
example_builder = tf_example_builder.TfExampleBuilder()
example = example_builder.serialized_example
self.assertProtoEquals('', example)
def test_add_feature(self):
example_builder = tf_example_builder.TfExampleBuilder()
example_builder.add_feature(
'foo',
tf.train.Feature(
bytes_list=tf.train.BytesList(value=[b'Hello World!'])))
example = example_builder.example
# Use proto text to show how the entire proto would look like.
self.assertProtoEquals(
"""
features: {
feature: {
key: "foo"
value: {
bytes_list: {
value: "Hello World!"
}
}
}
}""", example)
def test_add_feature_dict(self):
example_builder = tf_example_builder.TfExampleBuilder()
example_builder.add_feature_dict({
'foo':
tf.train.Feature(
bytes_list=tf.train.BytesList(value=[b'Hello World!'])),
'bar':
tf.train.Feature(
int64_list=tf.train.Int64List(value=[299, 792, 458]))
})
example = example_builder.example
# Use proto text to show how the entire proto would look like.
self.assertProtoEquals(
"""
features: {
feature: {
key: "foo"
value: {
bytes_list: {
value: "Hello World!"
}
}
}
feature: {
key: "bar"
value: {
int64_list: {
value: 299
value: 792
value: 458
}
}
}
}""", example)
@parameterized.named_parameters(
('single_bytes', b'Hello World!', b'Hello World!'),
('single_string', 'Hello World!', b'Hello World!'))
def test_add_single_byte_feature(self, value, expected_value):
example_builder = tf_example_builder.TfExampleBuilder()
example_builder.add_bytes_feature('foo', value)
example = example_builder.example
# Use constructor to easily work with test parameters.
self.assertProtoEquals(
tf.train.Example(
features=tf.train.Features(
feature={
'foo':
tf.train.Feature(
bytes_list=tf.train.BytesList(
value=[expected_value]))
})), example)
@parameterized.named_parameters(
('multiple_bytes', [b'Hello World!', b'Good Morning!'
], [b'Hello World!', b'Good Morning!']),
('multiple_string', ['Hello World!', 'Good Morning!'
], [b'Hello World!', b'Good Morning!']))
def test_add_multiple_bytes_feature(self, values, expected_values):
example_builder = tf_example_builder.TfExampleBuilder()
example_builder.add_bytes_feature('foo', values)
example = example_builder.example
self.assertProtoEquals(
tf.train.Example(
features=tf.train.Features(
feature={
'foo':
tf.train.Feature(
bytes_list=tf.train.BytesList(
value=expected_values))
})), example)
@parameterized.named_parameters(
('single_integer', 123, [123]),
('multiple_integers', [123, 456, 789], [123, 456, 789]))
def test_add_ints_feature(self, value, expected_value):
example_builder = tf_example_builder.TfExampleBuilder()
example_builder.add_ints_feature('bar', value)
example = example_builder.example
self.assertProtoEquals(
tf.train.Example(
features=tf.train.Features(
feature={
'bar':
tf.train.Feature(
int64_list=tf.train.Int64List(value=expected_value))
})), example)
@parameterized.named_parameters(
('single_float', 3.14, [3.14]),
('multiple_floats', [3.14, 1.57, 6.28], [3.14, 1.57, 6.28]))
def test_add_floats_feature(self, value, expected_value):
example_builder = tf_example_builder.TfExampleBuilder()
example_builder.add_floats_feature('baz', value)
example = example_builder.example
self.assertProtoEquals(
tf.train.Example(
features=tf.train.Features(
feature={
'baz':
tf.train.Feature(
float_list=tf.train.FloatList(value=expected_value))
})), example)
if __name__ == '__main__':
tf.test.main()
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data classes for tf.Example proto feature keys.
Feature keys are grouped by feature types. Key names follow conventions in
go/tf-example.
"""
import dataclasses
import functools
from typing import Optional
# Disable the init function to use the one defined in the base class.
dataclass = functools.partial(dataclasses.dataclass(init=False))
@dataclass
class TfExampleFeatureKeyBase:
"""Base dataclass for defining tf.Example proto feature keys.
This class defines the logic of adding prefix to feature keys. Subclasses
will define feature keys for a specific feature type in data fields.
NOTE: Please follow subclass examples in this module to define feature keys
for a new feature type.
"""
def __init__(self, prefix: Optional[str] = None):
"""Instantiates the feature key class.
Adds a string prefix to all fields of a feature key instance if `prefix` is
neither None nor empty.
Example usage:
>>> test_key = EncodedImageFeatureKey()
>>> test_key.encoded
image/encoded
>>> test_key = EncodedImageFeatureKey('prefix')
>>> test_key.encoded
prefix/image/encoded
Args:
prefix: A prefix string that will be added before the feature key string
with a trailing slash '/'.
"""
if prefix:
for field in dataclasses.fields(self):
key_name = field.name
key_value = getattr(self, key_name)
setattr(self, key_name, f'{prefix}/{key_value}')
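Following the NOTE above, a hypothetical subclass sketch; the key names mirror the `EncodedImageFeatureKey` used in the `__init__` docstring example and are illustrative only:

```python
# Hypothetical subclass sketch; the key strings are illustrative.
@dataclass
class EncodedImageFeatureKey(TfExampleFeatureKeyBase):
  """Feature keys for a single encoded image."""
  encoded: str = 'image/encoded'
  format: str = 'image/format'

# EncodedImageFeatureKey('prefix').encoded == 'prefix/image/encoded'
```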
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tf_example_feature_key."""
import dataclasses
import inspect
from absl.testing import absltest
from absl.testing import parameterized
from official.core import tf_example_feature_key
@tf_example_feature_key.dataclass
class TestFeatureKey(tf_example_feature_key.TfExampleFeatureKeyBase):
test: str = 'foo/bar'
class TfExampleFeatureKeyTest(parameterized.TestCase):
def test_add_prefix_success(self):
test_key = TestFeatureKey('prefix')
self.assertEqual(test_key.test, 'prefix/foo/bar')
@parameterized.parameters(None, '')
def test_add_prefix_skip_success(self, prefix):
test_key = TestFeatureKey(prefix)
self.assertEqual(test_key.test, 'foo/bar')
def test_all_feature_key_classes_are_valid(self):
for _, obj in inspect.getmembers(tf_example_feature_key):
if inspect.isclass(obj):
self.assertTrue(dataclasses.is_dataclass(obj))
self.assertTrue(
issubclass(obj, tf_example_feature_key.TfExampleFeatureKeyBase))
if __name__ == '__main__':
absltest.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,7 +15,7 @@
"""TFM common training driver library."""
# pytype: disable=attribute-error
import os
from typing import Any, Mapping, Optional, Tuple
from typing import Any, Mapping, Optional, Tuple, List
# Import libraries
@@ -32,7 +32,29 @@ from official.core import train_utils
maybe_create_best_ckpt_exporter = train_utils.maybe_create_best_ckpt_exporter
def run_experiment(
class OrbitExperimentRunner:
"""Runs experiment with Orbit training loop.
The default experiment runner for model garden experiments. Users can
customize the experiment pipeline by subclassing this class and replacing
components or functions.
For example, an experiment runner with a customized checkpoint manager:
```python
class MyExpRunnerWithExporter(OrbitExperimentRunner):
  def _maybe_build_checkpoint_manager(self):
    return MyCheckpointManager(*args)
# In user code
MyExpRunnerWithExporter(**needed_kwargs).run()
```
Similar overrides can be done for other components.
"""
def __init__(
self,
distribution_strategy: tf.distribute.Strategy,
task: base_task.Task,
mode: str,
@@ -40,111 +62,245 @@ def run_experiment(
model_dir: str,
run_post_eval: bool = False,
save_summary: bool = True,
train_actions: Optional[List[orbit.Action]] = None,
eval_actions: Optional[List[orbit.Action]] = None,
trainer: Optional[base_trainer.Trainer] = None,
controller_cls=orbit.Controller
) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
"""Runs train/eval configured by the experiment params.
):
"""Constructor.
Args:
distribution_strategy: A distribution distribution_strategy.
distribution_strategy: A distribution strategy.
task: A Task instance.
mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval'
or 'continuous_eval'.
mode: A 'str', specifying the mode. Can be 'train', 'eval',
'train_and_eval' or 'continuous_eval'.
params: ExperimentConfig instance.
model_dir: A 'str', a path to store model checkpoints and summaries.
run_post_eval: Whether to run post eval once after training, metrics logs
are returned.
save_summary: Whether to save train and validation summary.
trainer: the base_trainer.Trainer instance. It should be created within the
strategy.scope().
train_actions: Optional list of Orbit train actions.
eval_actions: Optional list of Orbit eval actions.
trainer: the base_trainer.Trainer instance. It should be created within
the strategy.scope().
controller_cls: The controller class to manage the train and eval process.
Must be a orbit.Controller subclass.
Returns:
A 2-tuple of (model, eval_logs).
model: `tf.keras.Model` instance.
eval_logs: returns eval metrics logs when run_post_eval is set to True,
otherwise, returns {}.
"""
self.strategy = distribution_strategy or tf.distribute.get_strategy()
self._params = params
self._model_dir = model_dir
self._mode = mode
self._run_post_eval = run_post_eval
with distribution_strategy.scope():
if not trainer:
trainer = train_utils.create_trainer(
params,
self._trainer = trainer or self._build_trainer(
task,
train='train' in mode,
evaluate=('eval' in mode) or run_post_eval,
checkpoint_exporter=maybe_create_best_ckpt_exporter(
params, model_dir))
evaluate=('eval' in mode) or run_post_eval)
assert self.trainer is not None
self._checkpoint_manager = self._maybe_build_checkpoint_manager()
self._controller = self._build_controller(
trainer=self.trainer if 'train' in mode else None,
evaluator=self.trainer,
save_summary=save_summary,
train_actions=train_actions,
eval_actions=eval_actions,
controller_cls=controller_cls)
@property
def params(self) -> config_definitions.ExperimentConfig:
return self._params
@property
def model_dir(self) -> str:
return self._model_dir
@property
def trainer(self) -> base_trainer.Trainer:
return self._trainer
@property
def checkpoint_manager(self) -> tf.train.CheckpointManager:
return self._checkpoint_manager
@property
def controller(self) -> orbit.Controller:
return self._controller
def _build_trainer(self, task: base_task.Task, train: bool,
evaluate: bool) -> base_trainer.Trainer:
"""Create trainer."""
with self.strategy.scope():
trainer = train_utils.create_trainer(
self.params,
task,
train=train,
evaluate=evaluate,
checkpoint_exporter=self._build_best_checkpoint_exporter())
return trainer
if trainer.checkpoint:
if model_dir is None:
def _build_best_checkpoint_exporter(self):
return maybe_create_best_ckpt_exporter(self.params, self.model_dir)
def _maybe_build_checkpoint_manager(
self) -> Optional[tf.train.CheckpointManager]:
"""Maybe create a CheckpointManager."""
assert self.trainer is not None
if self.trainer.checkpoint:
if self.model_dir is None:
raise ValueError('model_dir must be specified, but got None')
checkpoint_manager = tf.train.CheckpointManager(
trainer.checkpoint,
directory=model_dir,
max_to_keep=params.trainer.max_to_keep,
step_counter=trainer.global_step,
checkpoint_interval=params.trainer.checkpoint_interval,
init_fn=trainer.initialize)
self.trainer.checkpoint,
directory=self.model_dir,
max_to_keep=self.params.trainer.max_to_keep,
step_counter=self.trainer.global_step,
checkpoint_interval=self.params.trainer.checkpoint_interval,
init_fn=self.trainer.initialize)
else:
checkpoint_manager = None
return checkpoint_manager
def _build_controller(self,
trainer,
evaluator,
save_summary: bool = True,
train_actions: Optional[List[orbit.Action]] = None,
eval_actions: Optional[List[orbit.Action]] = None,
controller_cls=orbit.Controller) -> orbit.Controller:
"""Builds a Orbit controler."""
train_actions = [] if not train_actions else train_actions
if trainer:
train_actions += actions.get_train_actions(
self.params,
trainer,
self.model_dir,
checkpoint_manager=self.checkpoint_manager)
eval_actions = [] if not eval_actions else eval_actions
if evaluator:
eval_actions += actions.get_eval_actions(self.params, evaluator,
self.model_dir)
controller = controller_cls(
strategy=distribution_strategy,
trainer=trainer if 'train' in mode else None,
evaluator=trainer,
global_step=trainer.global_step,
steps_per_loop=params.trainer.steps_per_loop,
checkpoint_manager=checkpoint_manager,
summary_dir=os.path.join(model_dir, 'train') if (save_summary) else None,
eval_summary_dir=os.path.join(model_dir,
params.trainer.validation_summary_subdir) if
strategy=self.strategy,
trainer=trainer,
evaluator=evaluator,
global_step=self.trainer.global_step,
steps_per_loop=self.params.trainer.steps_per_loop,
checkpoint_manager=self.checkpoint_manager,
summary_dir=os.path.join(self.model_dir, 'train') if
(save_summary) else None,
eval_summary_dir=os.path.join(
self.model_dir, self.params.trainer.validation_summary_subdir) if
(save_summary) else None,
summary_interval=params.trainer.summary_interval if
summary_interval=self.params.trainer.summary_interval if
(save_summary) else None,
train_actions=actions.get_train_actions(
params, trainer, model_dir, checkpoint_manager=checkpoint_manager),
eval_actions=actions.get_eval_actions(params, trainer, model_dir))
train_actions=train_actions,
eval_actions=eval_actions)
return controller
def run(self) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
"""Run experiments by mode.
Returns:
A 2-tuple of (model, eval_logs).
model: `tf.keras.Model` instance.
eval_logs: returns eval metrics logs when run_post_eval is set to True,
otherwise, returns {}.
"""
mode = self._mode
params = self.params
logging.info('Starts to execute mode: %s', mode)
with distribution_strategy.scope():
if mode == 'train':
controller.train(steps=params.trainer.train_steps)
with self.strategy.scope():
if mode == 'train' or mode == 'train_and_post_eval':
self.controller.train(steps=params.trainer.train_steps)
elif mode == 'train_and_eval':
controller.train_and_evaluate(
self.controller.train_and_evaluate(
train_steps=params.trainer.train_steps,
eval_steps=params.trainer.validation_steps,
eval_interval=params.trainer.validation_interval)
elif mode == 'eval':
controller.evaluate(steps=params.trainer.validation_steps)
self.controller.evaluate(steps=params.trainer.validation_steps)
elif mode == 'continuous_eval':
def timeout_fn():
if trainer.global_step.numpy() >= params.trainer.train_steps:
if self.trainer.global_step.numpy() >= params.trainer.train_steps:
return True
return False
controller.evaluate_continuously(
self.controller.evaluate_continuously(
steps=params.trainer.validation_steps,
timeout=params.trainer.continuous_eval_timeout,
timeout_fn=timeout_fn)
else:
raise NotImplementedError('The mode is not implemented: %s' % mode)
num_params = train_utils.try_count_params(trainer.model)
num_params = train_utils.try_count_params(self.trainer.model)
if num_params is not None:
logging.info('Number of trainable params in model: %f Millions.',
num_params / 10.**6)
flops = train_utils.try_count_flops(trainer.model)
flops = train_utils.try_count_flops(self.trainer.model)
if flops is not None:
logging.info('FLOPs (multi-adds) in model: %f Billions.',
flops / 10.**9 / 2)
if run_post_eval:
with distribution_strategy.scope():
return trainer.model, trainer.evaluate(
tf.convert_to_tensor(params.trainer.validation_steps))
if self._run_post_eval or mode == 'train_and_post_eval':
with self.strategy.scope():
return self.trainer.model, self.controller.evaluate(
steps=params.trainer.validation_steps)
else:
return trainer.model, {}
return self.trainer.model, {}
def run_experiment(
distribution_strategy: tf.distribute.Strategy,
task: base_task.Task,
mode: str,
params: config_definitions.ExperimentConfig,
model_dir: str,
run_post_eval: bool = False,
save_summary: bool = True,
train_actions: Optional[List[orbit.Action]] = None,
eval_actions: Optional[List[orbit.Action]] = None,
trainer: Optional[base_trainer.Trainer] = None,
controller_cls=orbit.Controller
) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
"""Runs train/eval configured by the experiment params.
Args:
distribution_strategy: A distribution strategy.
task: A Task instance.
mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval'
or 'continuous_eval'.
params: ExperimentConfig instance.
model_dir: A 'str', a path to store model checkpoints and summaries.
run_post_eval: Whether to run post eval once after training, metrics logs
are returned.
save_summary: Whether to save train and validation summary.
train_actions: Optional list of Orbit train actions.
eval_actions: Optional list of Orbit eval actions.
trainer: the base_trainer.Trainer instance. It should be created within the
strategy.scope().
controller_cls: The controller class to manage the train and eval process.
Must be a orbit.Controller subclass.
Returns:
A 2-tuple of (model, eval_logs).
model: `tf.keras.Model` instance.
eval_logs: returns eval metrics logs when run_post_eval is set to True,
otherwise, returns {}.
"""
runner = OrbitExperimentRunner(
distribution_strategy=distribution_strategy,
task=task,
mode=mode,
params=params,
model_dir=model_dir,
run_post_eval=run_post_eval,
save_summary=save_summary,
train_actions=train_actions,
eval_actions=eval_actions,
trainer=trainer,
controller_cls=controller_cls,
)
return runner.run()
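A hedged end-to-end sketch of calling the wrapper; `my_task` and `experiment_config` stand in for user-provided objects, and `tf` is assumed imported as in the module above:

```python
# Usage sketch: train and evaluate under a mirrored strategy, then collect
# post-training eval metrics. `my_task` / `experiment_config` are
# placeholders for a base_task.Task and an ExperimentConfig.
strategy = tf.distribute.MirroredStrategy()
model, eval_logs = run_experiment(
    distribution_strategy=strategy,
    task=my_task,
    mode='train_and_eval',
    params=experiment_config,
    model_dir='/tmp/model_dir',  # placeholder
    run_post_eval=True)
```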