Commit f3105295 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 414599442
parent eb501b56
...@@ -16,10 +16,13 @@ ...@@ -16,10 +16,13 @@
import abc import abc
import functools import functools
import time
from typing import Any, Callable, Dict, Mapping, List, Optional, Text, Union from typing import Any, Callable, Dict, Mapping, List, Optional, Text, Union
from absl import logging
import tensorflow as tf import tensorflow as tf
from tensorflow.python.saved_model.model_utils import export_utils
MAX_DIRECTORY_CREATION_ATTEMPTS = 10
class ExportModule(tf.Module, metaclass=abc.ABCMeta): class ExportModule(tf.Module, metaclass=abc.ABCMeta):
...@@ -119,15 +122,48 @@ def export(export_module: ExportModule, ...@@ -119,15 +122,48 @@ def export(export_module: ExportModule,
} }
else: else:
raise ValueError( raise ValueError(
"If the function_keys is a list, it must contain a single element. %s" 'If the function_keys is a list, it must contain a single element. %s'
% function_keys) % function_keys)
signatures = export_module.get_inference_signatures(function_keys) signatures = export_module.get_inference_signatures(function_keys)
if timestamped: if timestamped:
export_dir = export_utils.get_timestamped_export_dir( export_dir = get_timestamped_export_dir(export_savedmodel_dir).decode(
export_savedmodel_dir).decode("utf-8") 'utf-8')
else: else:
export_dir = export_savedmodel_dir export_dir = export_savedmodel_dir
tf.saved_model.save( tf.saved_model.save(
export_module, export_dir, signatures=signatures, options=save_options) export_module, export_dir, signatures=signatures, options=save_options)
return export_dir return export_dir
def get_timestamped_export_dir(export_dir_base):
"""Builds a path to a new subdirectory within the base directory.
Args:
export_dir_base: A string containing a directory to write the exported graph
and checkpoints.
Returns:
The full path of the new subdirectory (which is not actually created yet).
Raises:
RuntimeError: if repeated attempts fail to obtain a unique timestamped
directory name.
"""
attempts = 0
while attempts < MAX_DIRECTORY_CREATION_ATTEMPTS:
timestamp = int(time.time())
result_dir = tf.io.gfile.join(
tf.compat.as_bytes(export_dir_base), tf.compat.as_bytes(str(timestamp)))
if not tf.io.gfile.exists(result_dir):
# Collisions are still possible (though extremely unlikely): this
# directory is not actually created yet, but it will be almost
# instantly on return from this function.
return result_dir
time.sleep(1)
attempts += 1
logging.warning('Directory %s already exists; retrying (attempt %s/%s)',
str(result_dir), attempts, MAX_DIRECTORY_CREATION_ATTEMPTS)
raise RuntimeError('Failed to obtain a unique export directory name after '
f'{MAX_DIRECTORY_CREATION_ATTEMPTS} attempts.')
...@@ -121,6 +121,13 @@ class ExportBaseTest(tf.test.TestCase): ...@@ -121,6 +121,13 @@ class ExportBaseTest(tf.test.TestCase):
output = module.serve(inputs) output = module.serve(inputs)
self.assertAllClose(output['outputs'].numpy(), 1.11) self.assertAllClose(output['outputs'].numpy(), 1.11)
def test_get_timestamped_export_dir(self):
export_dir = self.get_temp_dir()
timed_dir = export_base.get_timestamped_export_dir(
export_dir_base=export_dir)
self.assertFalse(tf.io.gfile.exists(timed_dir))
self.assertIn(export_dir, str(timed_dir))
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
...@@ -13,17 +13,13 @@ ...@@ -13,17 +13,13 @@
# limitations under the License. # limitations under the License.
"""Common library to export a SavedModel from the export module.""" """Common library to export a SavedModel from the export module."""
import os
import time
from typing import Dict, List, Optional, Text, Union from typing import Dict, List, Optional, Text, Union
from absl import logging
import tensorflow as tf import tensorflow as tf
from official.core import export_base from official.core import export_base
get_timestamped_export_dir = export_base.get_timestamped_export_dir
MAX_DIRECTORY_CREATION_ATTEMPTS = 10
def export(export_module: export_base.ExportModule, def export(export_module: export_base.ExportModule,
...@@ -50,35 +46,3 @@ def export(export_module: export_base.ExportModule, ...@@ -50,35 +46,3 @@ def export(export_module: export_base.ExportModule,
}) })
return export_base.export(export_module, function_keys, export_savedmodel_dir, return export_base.export(export_module, function_keys, export_savedmodel_dir,
checkpoint_path, timestamped, save_options) checkpoint_path, timestamped, save_options)
def get_timestamped_export_dir(export_dir_base):
"""Builds a path to a new subdirectory within the base directory.
Args:
export_dir_base: A string containing a directory to write the exported graph
and checkpoints.
Returns:
The full path of the new subdirectory (which is not actually created yet).
Raises:
RuntimeError: if repeated attempts fail to obtain a unique timestamped
directory name.
"""
attempts = 0
while attempts < MAX_DIRECTORY_CREATION_ATTEMPTS:
timestamp = int(time.time())
result_dir = os.path.join(export_dir_base, str(timestamp))
if not tf.io.gfile.exists(result_dir):
# Collisions are still possible (though extremely unlikely): this
# directory is not actually created yet, but it will be almost
# instantly on return from this function.
return result_dir
time.sleep(1)
attempts += 1
logging.warning('Directory %s already exists; retrying (attempt %s/%s)',
str(result_dir), attempts, MAX_DIRECTORY_CREATION_ATTEMPTS)
raise RuntimeError('Failed to obtain a unique export directory name after '
f'{MAX_DIRECTORY_CREATION_ATTEMPTS} attempts.')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment