"examples/mxnet/vscode:/vscode.git/clone" did not exist on "993fd3f94baff137216d2b16dade638f3b6c99c3"
Commit b729e4ec authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 448644197
parent bce895d9
...@@ -61,6 +61,7 @@ class ExportModule(export_base.ExportModule): ...@@ -61,6 +61,7 @@ class ExportModule(export_base.ExportModule):
preprocessor=preprocessor, preprocessor=preprocessor,
inference_step=inference_step, inference_step=inference_step,
postprocessor=postprocessor) postprocessor=postprocessor)
self.eval_postprocessor = eval_postprocessor
self.input_signature = input_signature self.input_signature = input_signature
@tf.function @tf.function
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
r"""Vision models export utility function for serving/inference.""" r"""Vision models export utility function for serving/inference."""
import os import os
from typing import Optional, List from typing import Optional, List, Union, Text, Dict
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
...@@ -44,7 +44,8 @@ def export_inference_graph( ...@@ -44,7 +44,8 @@ def export_inference_graph(
save_options: Optional[tf.saved_model.SaveOptions] = None, save_options: Optional[tf.saved_model.SaveOptions] = None,
log_model_flops_and_params: bool = False, log_model_flops_and_params: bool = False,
checkpoint: Optional[tf.train.Checkpoint] = None, checkpoint: Optional[tf.train.Checkpoint] = None,
input_name: Optional[str] = None): input_name: Optional[str] = None,
function_keys: Optional[Union[List[Text], Dict[Text, Text]]] = None,):
"""Exports inference graph for the model specified in the exp config. """Exports inference graph for the model specified in the exp config.
Saved model is stored at export_dir/saved_model, checkpoint is saved Saved model is stored at export_dir/saved_model, checkpoint is saved
...@@ -72,6 +73,9 @@ def export_inference_graph( ...@@ -72,6 +73,9 @@ def export_inference_graph(
will use it to read the weights. will use it to read the weights.
input_name: The input tensor name, default at `None` which produces input input_name: The input tensor name, default at `None` which produces input
tensor name `inputs`. tensor name `inputs`.
function_keys: a list of string keys to retrieve pre-defined serving
signatures. The signaute keys will be set with defaults. If a dictionary
is provided, the values will be used as signature keys.
""" """
if export_checkpoint_subdir: if export_checkpoint_subdir:
...@@ -130,7 +134,7 @@ def export_inference_graph( ...@@ -130,7 +134,7 @@ def export_inference_graph(
export_base.export( export_base.export(
export_module, export_module,
function_keys=[input_type], function_keys=function_keys if function_keys else [input_type],
export_savedmodel_dir=output_saved_model_directory, export_savedmodel_dir=output_saved_model_directory,
checkpoint=checkpoint, checkpoint=checkpoint,
checkpoint_path=checkpoint_path, checkpoint_path=checkpoint_path,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment