Commit 834ca16d authored by Abdullah Rashwan's avatar Abdullah Rashwan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 359198847
parent 2eb177da
...@@ -17,19 +17,29 @@ r"""Vision models export utility function for serving/inference.""" ...@@ -17,19 +17,29 @@ r"""Vision models export utility function for serving/inference."""
import os import os
from typing import Optional, List
import tensorflow as tf import tensorflow as tf
from official.core import config_definitions as cfg
from official.core import train_utils from official.core import train_utils
from official.vision.beta import configs from official.vision.beta import configs
from official.vision.beta.serving import detection from official.vision.beta.serving import detection
from official.vision.beta.serving import export_base
from official.vision.beta.serving import image_classification from official.vision.beta.serving import image_classification
from official.vision.beta.serving import semantic_segmentation from official.vision.beta.serving import semantic_segmentation
def export_inference_graph(input_type, batch_size, input_image_size, params, def export_inference_graph(
checkpoint_path, export_dir, input_type: str,
export_checkpoint_subdir=None, batch_size: Optional[int],
export_saved_model_subdir=None): input_image_size: List[int],
params: cfg.ExperimentConfig,
checkpoint_path: str,
export_dir: str,
export_module: Optional[export_base.ExportModule] = None,
export_checkpoint_subdir: Optional[str] = None,
export_saved_model_subdir: Optional[str] = None):
"""Exports inference graph for the model specified in the exp config. """Exports inference graph for the model specified in the exp config.
Saved model is stored at export_dir/saved_model, checkpoint is saved Saved model is stored at export_dir/saved_model, checkpoint is saved
...@@ -42,6 +52,9 @@ def export_inference_graph(input_type, batch_size, input_image_size, params, ...@@ -42,6 +52,9 @@ def export_inference_graph(input_type, batch_size, input_image_size, params,
params: Experiment params. params: Experiment params.
checkpoint_path: Trained checkpoint path or directory. checkpoint_path: Trained checkpoint path or directory.
export_dir: Export directory path. export_dir: Export directory path.
export_module: Optional export module to be used instead of using params
to create one. If None, the params will be used to create an export
module.
export_checkpoint_subdir: Optional subdirectory under export_dir export_checkpoint_subdir: Optional subdirectory under export_dir
to store checkpoint. to store checkpoint.
export_saved_model_subdir: Optional subdirectory under export_dir export_saved_model_subdir: Optional subdirectory under export_dir
...@@ -60,18 +73,25 @@ def export_inference_graph(input_type, batch_size, input_image_size, params, ...@@ -60,18 +73,25 @@ def export_inference_graph(input_type, batch_size, input_image_size, params,
else: else:
output_saved_model_directory = export_dir output_saved_model_directory = export_dir
if not export_module:
if isinstance(params.task, if isinstance(params.task,
configs.image_classification.ImageClassificationTask): configs.image_classification.ImageClassificationTask):
export_module = image_classification.ClassificationModule( export_module = image_classification.ClassificationModule(
params=params, batch_size=batch_size, input_image_size=input_image_size) params=params,
batch_size=batch_size,
input_image_size=input_image_size)
elif isinstance(params.task, configs.retinanet.RetinaNetTask) or isinstance( elif isinstance(params.task, configs.retinanet.RetinaNetTask) or isinstance(
params.task, configs.maskrcnn.MaskRCNNTask): params.task, configs.maskrcnn.MaskRCNNTask):
export_module = detection.DetectionModule( export_module = detection.DetectionModule(
params=params, batch_size=batch_size, input_image_size=input_image_size) params=params,
batch_size=batch_size,
input_image_size=input_image_size)
elif isinstance(params.task, elif isinstance(params.task,
configs.semantic_segmentation.SemanticSegmentationTask): configs.semantic_segmentation.SemanticSegmentationTask):
export_module = semantic_segmentation.SegmentationModule( export_module = semantic_segmentation.SegmentationModule(
params=params, batch_size=batch_size, input_image_size=input_image_size) params=params,
batch_size=batch_size,
input_image_size=input_image_size)
else: else:
raise ValueError('Export module not implemented for {} task.'.format( raise ValueError('Export module not implemented for {} task.'.format(
type(params.task))) type(params.task)))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment