Commit 77026626 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 460263057
parent 7d45e7b9
...@@ -15,15 +15,16 @@ ...@@ -15,15 +15,16 @@
"""Custom checkpoint manager that also exports saved models.""" """Custom checkpoint manager that also exports saved models."""
import os import os
import re
from typing import Callable, Mapping, Optional from typing import Callable, Mapping, Optional
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
_SAVED_MODULES_PATH_SUFFIX = 'saved_modules'


def make_saved_modules_directory_name(checkpoint_name: str) -> str:
  """Returns the SavedModel directory name paired with `checkpoint_name`."""
  return '_'.join([checkpoint_name, _SAVED_MODULES_PATH_SUFFIX])
class SavedModelCheckpointManager(tf.train.CheckpointManager): class SavedModelCheckpointManager(tf.train.CheckpointManager):
...@@ -50,6 +51,7 @@ class SavedModelCheckpointManager(tf.train.CheckpointManager): ...@@ -50,6 +51,7 @@ class SavedModelCheckpointManager(tf.train.CheckpointManager):
checkpoint_interval=checkpoint_interval, checkpoint_interval=checkpoint_interval,
init_fn=init_fn) init_fn=init_fn)
self._modules_to_export = modules_to_export self._modules_to_export = modules_to_export
self._savedmodels = self._get_existing_savedmodels()
def save(self, def save(self,
checkpoint_number=None, checkpoint_number=None,
...@@ -73,21 +75,49 @@ class SavedModelCheckpointManager(tf.train.CheckpointManager): ...@@ -73,21 +75,49 @@ class SavedModelCheckpointManager(tf.train.CheckpointManager):
obj=model, obj=model,
export_dir=os.path.join(saved_modules_directory, model_name)) export_dir=os.path.join(saved_modules_directory, model_name))
# `checkpoint_path` ends in `-[\d]+`. We want to glob for all existing
# checkpoints, and we use the .index file for that.
checkpoint_glob = re.sub(r'\d+$', '*.index', checkpoint_path)
existing_checkpoint_files = tf.io.gfile.glob(checkpoint_glob)
saved_modules_directories_to_keep = [ saved_modules_directories_to_keep = [
make_saved_modules_directory_name(os.path.splitext(ckpt_index)[0]) make_saved_modules_directory_name(ckpt) for ckpt in self.checkpoints
for ckpt_index in existing_checkpoint_files
] ]
saved_modules_glob = re.sub(r'\d+_saved_modules$', '*_saved_modules', existing_saved_modules_dirs = self._get_existing_savedmodels()
saved_modules_directory)
self._savedmodels = []
# Keep savedmodels in the same order as checkpoints (from oldest to newest).
for saved_modules_dir_to_keep in saved_modules_directories_to_keep:
if saved_modules_dir_to_keep in existing_saved_modules_dirs:
self._savedmodels.append(saved_modules_dir_to_keep)
for existing_saved_modules_dir in tf.io.gfile.glob(saved_modules_glob): for existing_saved_modules_dir in existing_saved_modules_dirs:
if (existing_saved_modules_dir not in saved_modules_directories_to_keep if existing_saved_modules_dir not in self._savedmodels:
and tf.io.gfile.isdir(existing_saved_modules_dir)):
tf.io.gfile.rmtree(existing_saved_modules_dir) tf.io.gfile.rmtree(existing_saved_modules_dir)
return checkpoint_path return checkpoint_path
def _get_existing_savedmodels(self):
  """Gets all existing SavedModel paths managed by this manager.

  Globs for directories named `<checkpoint_prefix>-<number>_saved_modules`.
  `tf.io.gfile.glob` makes no ordering guarantee, but this class keeps
  `self._savedmodels` ordered from oldest to newest and `latest_savedmodel`
  returns the last element, so the result is sorted numerically by the
  checkpoint number embedded in each path (plain lexicographic order would
  put e.g. `-10` before `-9`).

  Returns:
    A list of all existing SavedModel paths, ordered from oldest (lowest
    checkpoint number) to newest.
  """
  saved_modules_glob = make_saved_modules_directory_name(
      self._checkpoint_prefix + '-*')
  savedmodels = tf.io.gfile.glob(saved_modules_glob)

  number_re = re.compile(
      r'-(\d+)_' + re.escape(_SAVED_MODULES_PATH_SUFFIX) + r'$')

  def _checkpoint_number(path):
    # Paths not matching the expected pattern sort first; they should not
    # occur for glob results, but avoid crashing on unexpected entries.
    match = number_re.search(path)
    return int(match.group(1)) if match else -1

  return sorted(savedmodels, key=_checkpoint_number)
@property
def latest_savedmodel(self):
  """The path of the most recent SavedModel in `directory`.

  Returns:
    The newest SavedModel path, or `None` when no SavedModels exist.
  """
  return self._savedmodels[-1] if self._savedmodels else None
@property
def savedmodels(self):
  """All SavedModel paths managed by this manager.

  Returns:
    A list of SavedModel paths, ordered from oldest to newest.
  """
  return self._savedmodels
...@@ -51,6 +51,9 @@ class CheckpointManagerTest(tf.test.TestCase): ...@@ -51,6 +51,9 @@ class CheckpointManagerTest(tf.test.TestCase):
first_path = manager.save() first_path = manager.save()
second_path = manager.save() second_path = manager.save()
savedmodel = savedmodel_checkpoint_manager.make_saved_modules_directory_name(
manager.latest_checkpoint)
self.assertEqual(savedmodel, manager.latest_savedmodel)
self.assertTrue(_models_exist(second_path, models.keys())) self.assertTrue(_models_exist(second_path, models.keys()))
self.assertFalse(_models_exist(first_path, models.keys())) self.assertFalse(_models_exist(first_path, models.keys()))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment