Commit f1add1bc authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Extend checkpoint manager to be SavedModel checkpoint manager to store a...

Extend checkpoint manager to be SavedModel checkpoint manager to store a SavedModel when saving a checkpoint.

PiperOrigin-RevId: 458513556
parent 96ed89d1
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
"""Core is shared by both `nlp` and `vision`.""" """Core is shared by both `nlp` and `vision`."""
from official.core import actions from official.core import actions
from official.core import base_task from official.core import base_task
from official.core import base_trainer from official.core import base_trainer
...@@ -21,6 +22,7 @@ from official.core import exp_factory ...@@ -21,6 +22,7 @@ from official.core import exp_factory
from official.core import export_base from official.core import export_base
from official.core import input_reader from official.core import input_reader
from official.core import registry from official.core import registry
from official.core import savedmodel_checkpoint_manager
from official.core import task_factory from official.core import task_factory
from official.core import train_lib from official.core import train_lib
from official.core import train_utils from official.core import train_utils
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Custom checkpoint manager that also exports saved models."""
import os
import re
from typing import Callable, Mapping, Optional
from absl import logging
import tensorflow as tf
def make_saved_modules_directory_name(checkpoint_name: str) -> str:
return f'{checkpoint_name}_saved_modules'
class SavedModelCheckpointManager(tf.train.CheckpointManager):
"""A CheckpointManager that also exports `SavedModel`s."""
def __init__(self,
checkpoint: tf.train.Checkpoint,
directory: str,
max_to_keep: int,
modules_to_export: Optional[Mapping[str, tf.Module]] = None,
keep_checkpoint_every_n_hours: Optional[int] = None,
checkpoint_name: str = 'ckpt',
step_counter: Optional[tf.Variable] = None,
checkpoint_interval: Optional[int] = None,
init_fn: Optional[Callable[[], None]] = None):
"""See base class."""
super().__init__(
checkpoint=checkpoint,
directory=directory,
max_to_keep=max_to_keep,
keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
checkpoint_name=checkpoint_name,
step_counter=step_counter,
checkpoint_interval=checkpoint_interval,
init_fn=init_fn)
self._modules_to_export = modules_to_export
def save(self,
checkpoint_number=None,
check_interval: bool = True,
options: Optional[tf.train.CheckpointOptions] = None):
"""See base class."""
checkpoint_path = super().save(
checkpoint_number=checkpoint_number,
check_interval=check_interval,
options=options)
if not checkpoint_path: # Nothing got written.
return
if not self._modules_to_export: # No modules to export.
logging.info('Skip saving SavedModel due to empty modules_to_export.')
return checkpoint_path
# Save the models for the checkpoint that just got written.
saved_modules_directory = make_saved_modules_directory_name(checkpoint_path)
for model_name, model in self._modules_to_export.items():
tf.saved_model.save(
obj=model,
export_dir=os.path.join(saved_modules_directory, model_name))
# `checkpoint_path` ends in `-[\d]+`. We want to glob for all existing
# checkpoints, and we use the .index file for that.
checkpoint_glob = re.sub(r'\d+$', '*.index', checkpoint_path)
existing_checkpoint_files = tf.io.gfile.glob(checkpoint_glob)
saved_modules_directories_to_keep = [
make_saved_modules_directory_name(os.path.splitext(ckpt_index)[0])
for ckpt_index in existing_checkpoint_files
]
saved_modules_glob = re.sub(r'\d+_saved_modules$', '*_saved_modules',
saved_modules_directory)
for existing_saved_modules_dir in tf.io.gfile.glob(saved_modules_glob):
if (existing_saved_modules_dir not in saved_modules_directories_to_keep
and tf.io.gfile.isdir(existing_saved_modules_dir)):
tf.io.gfile.rmtree(existing_saved_modules_dir)
return checkpoint_path
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Iterable
import tensorflow as tf
from official.core import savedmodel_checkpoint_manager
def _models_exist(checkpoint_path: str, models: Iterable[str]) -> bool:
for model_name in models:
if not tf.io.gfile.isdir(
os.path.join(
savedmodel_checkpoint_manager.make_saved_modules_directory_name(
checkpoint_path), model_name)):
return False
return True
class CheckpointManagerTest(tf.test.TestCase):
def testSimpleTest(self):
models = {
"model_1":
tf.keras.Sequential(
layers=[tf.keras.layers.Dense(8, input_shape=(16,))]),
"model_2":
tf.keras.Sequential(
layers=[tf.keras.layers.Dense(16, input_shape=(32,))]),
}
checkpoint = tf.train.Checkpoint()
manager = savedmodel_checkpoint_manager.SavedModelCheckpointManager(
checkpoint=checkpoint,
directory=self.get_temp_dir(),
max_to_keep=1,
modules_to_export=models)
first_path = manager.save()
second_path = manager.save()
self.assertTrue(_models_exist(second_path, models.keys()))
self.assertFalse(_models_exist(first_path, models.keys()))
if __name__ == "__main__":
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment