Commit 245ec883 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 416088533
parent da2465fd
...@@ -92,7 +92,8 @@ def export(export_module: ExportModule, ...@@ -92,7 +92,8 @@ def export(export_module: ExportModule,
export_savedmodel_dir: Text, export_savedmodel_dir: Text,
checkpoint_path: Optional[Text] = None, checkpoint_path: Optional[Text] = None,
timestamped: bool = True, timestamped: bool = True,
save_options: Optional[tf.saved_model.SaveOptions] = None) -> Text: save_options: Optional[tf.saved_model.SaveOptions] = None,
checkpoint: Optional[tf.train.Checkpoint] = None) -> Text:
"""Exports to SavedModel format. """Exports to SavedModel format.
Args: Args:
...@@ -104,6 +105,8 @@ def export(export_module: ExportModule, ...@@ -104,6 +105,8 @@ def export(export_module: ExportModule,
checkpoint_path: Object-based checkpoint path or directory. checkpoint_path: Object-based checkpoint path or directory.
timestamped: Whether to export the savedmodel to a timestamped directory. timestamped: Whether to export the savedmodel to a timestamped directory.
save_options: `SaveOptions` for `tf.saved_model.save`. save_options: `SaveOptions` for `tf.saved_model.save`.
checkpoint: An optional tf.train.Checkpoint. If provided, the export module
will use it to read the weights.
Returns: Returns:
The savedmodel directory path. The savedmodel directory path.
...@@ -112,7 +115,8 @@ def export(export_module: ExportModule, ...@@ -112,7 +115,8 @@ def export(export_module: ExportModule,
if ckpt_dir_or_file is not None and tf.io.gfile.isdir(ckpt_dir_or_file): if ckpt_dir_or_file is not None and tf.io.gfile.isdir(ckpt_dir_or_file):
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file) ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
if ckpt_dir_or_file: if ckpt_dir_or_file:
checkpoint = tf.train.Checkpoint(model=export_module.model) if checkpoint is None:
checkpoint = tf.train.Checkpoint(model=export_module.model)
checkpoint.read( checkpoint.read(
ckpt_dir_or_file).assert_existing_objects_matched().expect_partial() ckpt_dir_or_file).assert_existing_objects_matched().expect_partial()
if isinstance(function_keys, list): if isinstance(function_keys, list):
......
...@@ -60,6 +60,10 @@ def define_flags(): ...@@ -60,6 +60,10 @@ def define_flags():
flags.DEFINE_string( flags.DEFINE_string(
"function_keys", None, "function_keys", None,
"A string key to retrieve pre-defined serving signatures.") "A string key to retrieve pre-defined serving signatures.")
flags.DEFINE_string(
"module_key", None,
"For multi-task case, load the export module weights from a specific "
"checkpoint item.")
flags.DEFINE_bool("convert_tpu", False, "") flags.DEFINE_bool("convert_tpu", False, "")
flags.DEFINE_multi_integer("allowed_batch_size", None, flags.DEFINE_multi_integer("allowed_batch_size", None,
"Allowed batch sizes for batching ops.") "Allowed batch sizes for batching ops.")
...@@ -116,7 +120,8 @@ def main(_): ...@@ -116,7 +120,8 @@ def main(_):
export_module, export_module,
function_keys=[FLAGS.function_keys], function_keys=[FLAGS.function_keys],
checkpoint_path=FLAGS.checkpoint_path, checkpoint_path=FLAGS.checkpoint_path,
export_savedmodel_dir=FLAGS.export_savedmodel_dir) export_savedmodel_dir=FLAGS.export_savedmodel_dir,
module_key=FLAGS.module_key)
if FLAGS.convert_tpu: if FLAGS.convert_tpu:
# pylint: disable=g-import-not-at-top # pylint: disable=g-import-not-at-top
......
...@@ -26,7 +26,8 @@ def export(export_module: export_base.ExportModule, ...@@ -26,7 +26,8 @@ def export(export_module: export_base.ExportModule,
function_keys: Union[List[Text], Dict[Text, Text]], function_keys: Union[List[Text], Dict[Text, Text]],
export_savedmodel_dir: Text, export_savedmodel_dir: Text,
checkpoint_path: Optional[Text] = None, checkpoint_path: Optional[Text] = None,
timestamped: bool = True) -> Text: timestamped: bool = True,
module_key: Optional[Text] = None) -> Text:
"""Exports to SavedModel format. """Exports to SavedModel format.
Args: Args:
...@@ -37,6 +38,8 @@ def export(export_module: export_base.ExportModule, ...@@ -37,6 +38,8 @@ def export(export_module: export_base.ExportModule,
export_savedmodel_dir: Output saved model directory. export_savedmodel_dir: Output saved model directory.
checkpoint_path: Object-based checkpoint path or directory. checkpoint_path: Object-based checkpoint path or directory.
timestamped: Whether to export the savedmodel to a timestamped directory. timestamped: Whether to export the savedmodel to a timestamped directory.
module_key: Optional string to identify a checkpoint object to load for the
model in the export module.
Returns: Returns:
The savedmodel directory path. The savedmodel directory path.
...@@ -44,5 +47,16 @@ def export(export_module: export_base.ExportModule, ...@@ -44,5 +47,16 @@ def export(export_module: export_base.ExportModule,
save_options = tf.saved_model.SaveOptions(function_aliases={ save_options = tf.saved_model.SaveOptions(function_aliases={
'tpu_candidate': export_module.serve, 'tpu_candidate': export_module.serve,
}) })
return export_base.export(export_module, function_keys, export_savedmodel_dir, if module_key:
checkpoint_path, timestamped, save_options) kwargs = {module_key: export_module.model}
checkpoint = tf.train.Checkpoint(**kwargs)
else:
checkpoint = None
return export_base.export(
export_module,
function_keys,
export_savedmodel_dir,
checkpoint_path,
timestamped,
save_options,
checkpoint=checkpoint)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment