Commit b38dd475 authored by Abdullah Rashwan's avatar Abdullah Rashwan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 359198847
parent 89031e1a
...@@ -17,19 +17,29 @@ r"""Vision models export utility function for serving/inference.""" ...@@ -17,19 +17,29 @@ r"""Vision models export utility function for serving/inference."""
import os import os
from typing import Optional, List
import tensorflow as tf import tensorflow as tf
from official.core import config_definitions as cfg
from official.core import train_utils from official.core import train_utils
from official.vision.beta import configs from official.vision.beta import configs
from official.vision.beta.serving import detection from official.vision.beta.serving import detection
from official.vision.beta.serving import export_base
from official.vision.beta.serving import image_classification from official.vision.beta.serving import image_classification
from official.vision.beta.serving import semantic_segmentation from official.vision.beta.serving import semantic_segmentation
def export_inference_graph(input_type, batch_size, input_image_size, params, def export_inference_graph(
checkpoint_path, export_dir, input_type: str,
export_checkpoint_subdir=None, batch_size: Optional[int],
export_saved_model_subdir=None): input_image_size: List[int],
params: cfg.ExperimentConfig,
checkpoint_path: str,
export_dir: str,
export_module: Optional[export_base.ExportModule] = None,
export_checkpoint_subdir: Optional[str] = None,
export_saved_model_subdir: Optional[str] = None):
"""Exports inference graph for the model specified in the exp config. """Exports inference graph for the model specified in the exp config.
Saved model is stored at export_dir/saved_model, checkpoint is saved Saved model is stored at export_dir/saved_model, checkpoint is saved
...@@ -42,6 +52,9 @@ def export_inference_graph(input_type, batch_size, input_image_size, params, ...@@ -42,6 +52,9 @@ def export_inference_graph(input_type, batch_size, input_image_size, params,
params: Experiment params. params: Experiment params.
checkpoint_path: Trained checkpoint path or directory. checkpoint_path: Trained checkpoint path or directory.
export_dir: Export directory path. export_dir: Export directory path.
export_module: Optional export module to be used instead of using params
to create one. If None, the params will be used to create an export
module.
export_checkpoint_subdir: Optional subdirectory under export_dir export_checkpoint_subdir: Optional subdirectory under export_dir
to store checkpoint. to store checkpoint.
export_saved_model_subdir: Optional subdirectory under export_dir export_saved_model_subdir: Optional subdirectory under export_dir
...@@ -60,21 +73,28 @@ def export_inference_graph(input_type, batch_size, input_image_size, params, ...@@ -60,21 +73,28 @@ def export_inference_graph(input_type, batch_size, input_image_size, params,
else: else:
output_saved_model_directory = export_dir output_saved_model_directory = export_dir
if isinstance(params.task, if not export_module:
configs.image_classification.ImageClassificationTask): if isinstance(params.task,
export_module = image_classification.ClassificationModule( configs.image_classification.ImageClassificationTask):
params=params, batch_size=batch_size, input_image_size=input_image_size) export_module = image_classification.ClassificationModule(
elif isinstance(params.task, configs.retinanet.RetinaNetTask) or isinstance( params=params,
params.task, configs.maskrcnn.MaskRCNNTask): batch_size=batch_size,
export_module = detection.DetectionModule( input_image_size=input_image_size)
params=params, batch_size=batch_size, input_image_size=input_image_size) elif isinstance(params.task, configs.retinanet.RetinaNetTask) or isinstance(
elif isinstance(params.task, params.task, configs.maskrcnn.MaskRCNNTask):
configs.semantic_segmentation.SemanticSegmentationTask): export_module = detection.DetectionModule(
export_module = semantic_segmentation.SegmentationModule( params=params,
params=params, batch_size=batch_size, input_image_size=input_image_size) batch_size=batch_size,
else: input_image_size=input_image_size)
raise ValueError('Export module not implemented for {} task.'.format( elif isinstance(params.task,
type(params.task))) configs.semantic_segmentation.SemanticSegmentationTask):
export_module = semantic_segmentation.SegmentationModule(
params=params,
batch_size=batch_size,
input_image_size=input_image_size)
else:
raise ValueError('Export module not implemented for {} task.'.format(
type(params.task)))
model = export_module.build_model() model = export_module.build_model()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment