Commit 77026626 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 460263057
parent 7d45e7b9
@@ -15,15 +15,16 @@
 """Custom checkpoint manager that also exports saved models."""
 
 import os
 import re
 from typing import Callable, Mapping, Optional
 
 from absl import logging
 import tensorflow as tf
 
+_SAVED_MODULES_PATH_SUFFIX = 'saved_modules'
 
 
 def make_saved_modules_directory_name(checkpoint_name: str) -> str:
-  return f'{checkpoint_name}_saved_modules'
+  return f'{checkpoint_name}_{_SAVED_MODULES_PATH_SUFFIX}'
 
 
 class SavedModelCheckpointManager(tf.train.CheckpointManager):
@@ -50,6 +51,7 @@ class SavedModelCheckpointManager(tf.train.CheckpointManager):
         checkpoint_interval=checkpoint_interval,
         init_fn=init_fn)
     self._modules_to_export = modules_to_export
+    self._savedmodels = self._get_existing_savedmodels()
 
   def save(self,
            checkpoint_number=None,
@@ -73,21 +75,49 @@ class SavedModelCheckpointManager(tf.train.CheckpointManager):
           obj=model,
           export_dir=os.path.join(saved_modules_directory, model_name))
 
-    # `checkpoint_path` ends in `-[\d]+`. We want to glob for all existing
-    # checkpoints, and we use the .index file for that.
-    checkpoint_glob = re.sub(r'\d+$', '*.index', checkpoint_path)
-    existing_checkpoint_files = tf.io.gfile.glob(checkpoint_glob)
     saved_modules_directories_to_keep = [
-        make_saved_modules_directory_name(os.path.splitext(ckpt_index)[0])
-        for ckpt_index in existing_checkpoint_files
+        make_saved_modules_directory_name(ckpt) for ckpt in self.checkpoints
     ]
-    saved_modules_glob = re.sub(r'\d+_saved_modules$', '*_saved_modules',
-                                saved_modules_directory)
+    existing_saved_modules_dirs = self._get_existing_savedmodels()
+
+    self._savedmodels = []
+    # Keep savedmodels in the same order as checkpoints (from oldest to newest).
+    for saved_modules_dir_to_keep in saved_modules_directories_to_keep:
+      if saved_modules_dir_to_keep in existing_saved_modules_dirs:
+        self._savedmodels.append(saved_modules_dir_to_keep)
 
-    for existing_saved_modules_dir in tf.io.gfile.glob(saved_modules_glob):
-      if (existing_saved_modules_dir not in saved_modules_directories_to_keep
-          and tf.io.gfile.isdir(existing_saved_modules_dir)):
+    for existing_saved_modules_dir in existing_saved_modules_dirs:
+      if existing_saved_modules_dir not in self._savedmodels:
         tf.io.gfile.rmtree(existing_saved_modules_dir)
 
     return checkpoint_path
 
+  def _get_existing_savedmodels(self):
+    """Gets a list of all existing SavedModel paths in `directory`.
+
+    Returns:
+      A list of all existing SavedModel paths.
+    """
+    saved_modules_glob = make_saved_modules_directory_name(
+        self._checkpoint_prefix + '-*')
+    return tf.io.gfile.glob(saved_modules_glob)
+
+  @property
+  def latest_savedmodel(self):
+    """The path of the most recent SavedModel in `directory`.
+
+    Returns:
+      The latest SavedModel path. If there are no SavedModels, returns `None`.
+    """
+    if self._savedmodels:
+      return self._savedmodels[-1]
+    return None
+
+  @property
+  def savedmodels(self):
+    """A list of managed SavedModels.
+
+    Returns:
+      A list of SavedModel paths, sorted from oldest to newest.
+    """
+    return self._savedmodels
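Note (not part of the commit): the glob in `_get_existing_savedmodels` finds every export directory because each one is derived from a checkpoint path by `make_saved_modules_directory_name`. A minimal sketch with a hypothetical checkpoint prefix:

  # Hypothetical prefix; tf.train.CheckpointManager names checkpoints '<prefix>-<n>'.
  checkpoint_prefix = '/tmp/training/ckpt'
  make_saved_modules_directory_name(checkpoint_prefix + '-42')
  # -> '/tmp/training/ckpt-42_saved_modules'
  make_saved_modules_directory_name(checkpoint_prefix + '-*')
  # -> '/tmp/training/ckpt-*_saved_modules', the pattern passed to tf.io.gfile.glob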
@@ -51,6 +51,9 @@ class CheckpointManagerTest(tf.test.TestCase):
     first_path = manager.save()
     second_path = manager.save()
+    savedmodel = savedmodel_checkpoint_manager.make_saved_modules_directory_name(
+        manager.latest_checkpoint)
+    self.assertEqual(savedmodel, manager.latest_savedmodel)
 
     self.assertTrue(_models_exist(second_path, models.keys()))
     self.assertFalse(_models_exist(first_path, models.keys()))
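For context, a hedged end-to-end sketch of the manager after this change. The import path and the toy model are assumptions; the constructor arguments mirror `tf.train.CheckpointManager` plus the `modules_to_export` mapping visible in the diff above.

  import tensorflow as tf
  from official.core import savedmodel_checkpoint_manager  # assumed import path

  model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
  manager = savedmodel_checkpoint_manager.SavedModelCheckpointManager(
      checkpoint=tf.train.Checkpoint(model=model),
      directory='/tmp/training',
      max_to_keep=3,
      modules_to_export={'model': model})

  for step in range(5):
    manager.save(checkpoint_number=step)

  # Only the SavedModel directories of the retained checkpoints survive,
  # sorted oldest to newest; `latest_savedmodel` is the last entry.
  print(manager.savedmodels)
  print(manager.latest_savedmodel)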