Commit 245ec883 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 416088533
parent da2465fd
......@@ -92,7 +92,8 @@ def export(export_module: ExportModule,
export_savedmodel_dir: Text,
checkpoint_path: Optional[Text] = None,
timestamped: bool = True,
save_options: Optional[tf.saved_model.SaveOptions] = None) -> Text:
save_options: Optional[tf.saved_model.SaveOptions] = None,
checkpoint: Optional[tf.train.Checkpoint] = None) -> Text:
"""Exports to SavedModel format.
Args:
......@@ -104,6 +105,8 @@ def export(export_module: ExportModule,
checkpoint_path: Object-based checkpoint path or directory.
timestamped: Whether to export the savedmodel to a timestamped directory.
save_options: `SaveOptions` for `tf.saved_model.save`.
checkpoint: An optional tf.train.Checkpoint. If provided, the export module
will use it to read the weights.
Returns:
The savedmodel directory path.
......@@ -112,6 +115,7 @@ def export(export_module: ExportModule,
if ckpt_dir_or_file is not None and tf.io.gfile.isdir(ckpt_dir_or_file):
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
if ckpt_dir_or_file:
if checkpoint is None:
checkpoint = tf.train.Checkpoint(model=export_module.model)
checkpoint.read(
ckpt_dir_or_file).assert_existing_objects_matched().expect_partial()
......
......@@ -60,6 +60,10 @@ def define_flags():
flags.DEFINE_string(
"function_keys", None,
"A string key to retrieve pre-defined serving signatures.")
flags.DEFINE_string(
"module_key", None,
"For multi-task case, load the export module weights from a specific "
"checkpoint item.")
flags.DEFINE_bool("convert_tpu", False, "")
flags.DEFINE_multi_integer("allowed_batch_size", None,
"Allowed batch sizes for batching ops.")
......@@ -116,7 +120,8 @@ def main(_):
export_module,
function_keys=[FLAGS.function_keys],
checkpoint_path=FLAGS.checkpoint_path,
export_savedmodel_dir=FLAGS.export_savedmodel_dir)
export_savedmodel_dir=FLAGS.export_savedmodel_dir,
module_key=FLAGS.module_key)
if FLAGS.convert_tpu:
# pylint: disable=g-import-not-at-top
......
......@@ -26,7 +26,8 @@ def export(export_module: export_base.ExportModule,
function_keys: Union[List[Text], Dict[Text, Text]],
export_savedmodel_dir: Text,
checkpoint_path: Optional[Text] = None,
timestamped: bool = True) -> Text:
timestamped: bool = True,
module_key: Optional[Text] = None) -> Text:
"""Exports to SavedModel format.
Args:
......@@ -37,6 +38,8 @@ def export(export_module: export_base.ExportModule,
export_savedmodel_dir: Output saved model directory.
checkpoint_path: Object-based checkpoint path or directory.
timestamped: Whether to export the savedmodel to a timestamped directory.
module_key: Optional string to identify a checkpoint object to load for the
model in the export module.
Returns:
The savedmodel directory path.
......@@ -44,5 +47,16 @@ def export(export_module: export_base.ExportModule,
save_options = tf.saved_model.SaveOptions(function_aliases={
'tpu_candidate': export_module.serve,
})
return export_base.export(export_module, function_keys, export_savedmodel_dir,
checkpoint_path, timestamped, save_options)
if module_key:
kwargs = {module_key: export_module.model}
checkpoint = tf.train.Checkpoint(**kwargs)
else:
checkpoint = None
return export_base.export(
export_module,
function_keys,
export_savedmodel_dir,
checkpoint_path,
timestamped,
save_options,
checkpoint=checkpoint)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment