Commit f3105295 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 414599442
parent eb501b56
......@@ -16,10 +16,13 @@
import abc
import functools
import time
from typing import Any, Callable, Dict, Mapping, List, Optional, Text, Union
from absl import logging
import tensorflow as tf
from tensorflow.python.saved_model.model_utils import export_utils
MAX_DIRECTORY_CREATION_ATTEMPTS = 10
class ExportModule(tf.Module, metaclass=abc.ABCMeta):
......@@ -119,15 +122,48 @@ def export(export_module: ExportModule,
}
else:
raise ValueError(
"If the function_keys is a list, it must contain a single element. %s"
'If the function_keys is a list, it must contain a single element. %s'
% function_keys)
signatures = export_module.get_inference_signatures(function_keys)
if timestamped:
export_dir = export_utils.get_timestamped_export_dir(
export_savedmodel_dir).decode("utf-8")
export_dir = get_timestamped_export_dir(export_savedmodel_dir).decode(
'utf-8')
else:
export_dir = export_savedmodel_dir
tf.saved_model.save(
export_module, export_dir, signatures=signatures, options=save_options)
return export_dir
def get_timestamped_export_dir(export_dir_base):
"""Builds a path to a new subdirectory within the base directory.
Args:
export_dir_base: A string containing a directory to write the exported graph
and checkpoints.
Returns:
The full path of the new subdirectory (which is not actually created yet).
Raises:
RuntimeError: if repeated attempts fail to obtain a unique timestamped
directory name.
"""
attempts = 0
while attempts < MAX_DIRECTORY_CREATION_ATTEMPTS:
timestamp = int(time.time())
result_dir = tf.io.gfile.join(
tf.compat.as_bytes(export_dir_base), tf.compat.as_bytes(str(timestamp)))
if not tf.io.gfile.exists(result_dir):
# Collisions are still possible (though extremely unlikely): this
# directory is not actually created yet, but it will be almost
# instantly on return from this function.
return result_dir
time.sleep(1)
attempts += 1
logging.warning('Directory %s already exists; retrying (attempt %s/%s)',
str(result_dir), attempts, MAX_DIRECTORY_CREATION_ATTEMPTS)
raise RuntimeError('Failed to obtain a unique export directory name after '
f'{MAX_DIRECTORY_CREATION_ATTEMPTS} attempts.')
......@@ -121,6 +121,13 @@ class ExportBaseTest(tf.test.TestCase):
output = module.serve(inputs)
self.assertAllClose(output['outputs'].numpy(), 1.11)
def test_get_timestamped_export_dir(self):
export_dir = self.get_temp_dir()
timed_dir = export_base.get_timestamped_export_dir(
export_dir_base=export_dir)
self.assertFalse(tf.io.gfile.exists(timed_dir))
self.assertIn(export_dir, str(timed_dir))
if __name__ == '__main__':
tf.test.main()
......@@ -13,17 +13,13 @@
# limitations under the License.
"""Common library to export a SavedModel from the export module."""
import os
import time
from typing import Dict, List, Optional, Text, Union
from absl import logging
import tensorflow as tf
from official.core import export_base
MAX_DIRECTORY_CREATION_ATTEMPTS = 10
get_timestamped_export_dir = export_base.get_timestamped_export_dir
def export(export_module: export_base.ExportModule,
......@@ -50,35 +46,3 @@ def export(export_module: export_base.ExportModule,
})
return export_base.export(export_module, function_keys, export_savedmodel_dir,
checkpoint_path, timestamped, save_options)
def get_timestamped_export_dir(export_dir_base):
"""Builds a path to a new subdirectory within the base directory.
Args:
export_dir_base: A string containing a directory to write the exported graph
and checkpoints.
Returns:
The full path of the new subdirectory (which is not actually created yet).
Raises:
RuntimeError: if repeated attempts fail to obtain a unique timestamped
directory name.
"""
attempts = 0
while attempts < MAX_DIRECTORY_CREATION_ATTEMPTS:
timestamp = int(time.time())
result_dir = os.path.join(export_dir_base, str(timestamp))
if not tf.io.gfile.exists(result_dir):
# Collisions are still possible (though extremely unlikely): this
# directory is not actually created yet, but it will be almost
# instantly on return from this function.
return result_dir
time.sleep(1)
attempts += 1
logging.warning('Directory %s already exists; retrying (attempt %s/%s)',
str(result_dir), attempts, MAX_DIRECTORY_CREATION_ATTEMPTS)
raise RuntimeError('Failed to obtain a unique export directory name after '
f'{MAX_DIRECTORY_CREATION_ATTEMPTS} attempts.')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment