Commit 5b99c99c authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 395996628
parent d3d4177d
...@@ -13,12 +13,19 @@ ...@@ -13,12 +13,19 @@
# limitations under the License. # limitations under the License.
"""Common library to export a SavedModel from the export module.""" """Common library to export a SavedModel from the export module."""
import os
import time
from typing import Dict, List, Optional, Text, Union from typing import Dict, List, Optional, Text, Union
from absl import logging
import tensorflow as tf import tensorflow as tf
from official.core import export_base from official.core import export_base
MAX_DIRECTORY_CREATION_ATTEMPTS = 10
def export(export_module: export_base.ExportModule, def export(export_module: export_base.ExportModule,
function_keys: Union[List[Text], Dict[Text, Text]], function_keys: Union[List[Text], Dict[Text, Text]],
export_savedmodel_dir: Text, export_savedmodel_dir: Text,
...@@ -39,7 +46,39 @@ def export(export_module: export_base.ExportModule, ...@@ -39,7 +46,39 @@ def export(export_module: export_base.ExportModule,
The savedmodel directory path. The savedmodel directory path.
""" """
save_options = tf.saved_model.SaveOptions(function_aliases={ save_options = tf.saved_model.SaveOptions(function_aliases={
"tpu_candidate": export_module.serve, 'tpu_candidate': export_module.serve,
}) })
return export_base.export(export_module, function_keys, export_savedmodel_dir, return export_base.export(export_module, function_keys, export_savedmodel_dir,
checkpoint_path, timestamped, save_options) checkpoint_path, timestamped, save_options)
def get_timestamped_export_dir(export_dir_base):
"""Builds a path to a new subdirectory within the base directory.
Args:
export_dir_base: A string containing a directory to write the exported graph
and checkpoints.
Returns:
The full path of the new subdirectory (which is not actually created yet).
Raises:
RuntimeError: if repeated attempts fail to obtain a unique timestamped
directory name.
"""
attempts = 0
while attempts < MAX_DIRECTORY_CREATION_ATTEMPTS:
timestamp = int(time.time())
result_dir = os.path.join(export_dir_base, str(timestamp))
if not tf.io.gfile.exists(result_dir):
# Collisions are still possible (though extremely unlikely): this
# directory is not actually created yet, but it will be almost
# instantly on return from this function.
return result_dir
time.sleep(1)
attempts += 1
logging.warning('Directory %s already exists; retrying (attempt %s/%s)',
str(result_dir), attempts, MAX_DIRECTORY_CREATION_ATTEMPTS)
raise RuntimeError('Failed to obtain a unique export directory name after '
f'{MAX_DIRECTORY_CREATION_ATTEMPTS} attempts.')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment