Commit 3624730d authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Add checkpoint arg to export_inference_graph

PiperOrigin-RevId: 436288626
parent c7f443b7
......@@ -43,7 +43,8 @@ def export_inference_graph(
export_checkpoint_subdir: Optional[str] = None,
export_saved_model_subdir: Optional[str] = None,
save_options: Optional[tf.saved_model.SaveOptions] = None,
log_model_flops_and_params: bool = False):
log_model_flops_and_params: bool = False,
checkpoint: Optional[tf.train.Checkpoint] = None):
"""Exports inference graph for the model specified in the exp config.
Saved model is stored at export_dir/saved_model, checkpoint is saved
......@@ -67,6 +68,8 @@ def export_inference_graph(
save_options: `SaveOptions` for `tf.saved_model.save`.
log_model_flops_and_params: If True, writes model FLOPs to model_flops.txt
and model parameters to model_params.txt.
checkpoint: An optional tf.train.Checkpoint. If provided, the export module
will use it to read the weights.
"""
if export_checkpoint_subdir:
......@@ -123,6 +126,7 @@ def export_inference_graph(
export_module,
function_keys=[input_type],
export_savedmodel_dir=output_saved_model_directory,
checkpoint=checkpoint,
checkpoint_path=checkpoint_path,
timestamped=False,
save_options=save_options)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment