Commit b38dd475 authored by Abdullah Rashwan, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 359198847
parent 89031e1a
@@ -17,19 +17,29 @@ r"""Vision models export utility function for serving/inference."""
import os
from typing import Optional, List
import tensorflow as tf
from official.core import config_definitions as cfg
from official.core import train_utils
from official.vision.beta import configs
from official.vision.beta.serving import detection
from official.vision.beta.serving import export_base
from official.vision.beta.serving import image_classification
from official.vision.beta.serving import semantic_segmentation
def export_inference_graph(input_type, batch_size, input_image_size, params,
checkpoint_path, export_dir,
export_checkpoint_subdir=None,
export_saved_model_subdir=None):
def export_inference_graph(
input_type: str,
batch_size: Optional[int],
input_image_size: List[int],
params: cfg.ExperimentConfig,
checkpoint_path: str,
export_dir: str,
export_module: Optional[export_base.ExportModule] = None,
export_checkpoint_subdir: Optional[str] = None,
export_saved_model_subdir: Optional[str] = None):
"""Exports inference graph for the model specified in the exp config.
Saved model is stored at export_dir/saved_model, checkpoint is saved
@@ -42,6 +52,9 @@ def export_inference_graph(input_type, batch_size, input_image_size, params,
params: Experiment params.
checkpoint_path: Trained checkpoint path or directory.
export_dir: Export directory path.
export_module: Optional export module to be used instead of using params
to create one. If None, the params will be used to create an export
module.
export_checkpoint_subdir: Optional subdirectory under export_dir
to store checkpoint.
export_saved_model_subdir: Optional subdirectory under export_dir
@@ -60,21 +73,28 @@ def export_inference_graph(input_type, batch_size, input_image_size, params,
else:
output_saved_model_directory = export_dir
if isinstance(params.task,
configs.image_classification.ImageClassificationTask):
export_module = image_classification.ClassificationModule(
params=params, batch_size=batch_size, input_image_size=input_image_size)
elif isinstance(params.task, configs.retinanet.RetinaNetTask) or isinstance(
params.task, configs.maskrcnn.MaskRCNNTask):
export_module = detection.DetectionModule(
params=params, batch_size=batch_size, input_image_size=input_image_size)
elif isinstance(params.task,
configs.semantic_segmentation.SemanticSegmentationTask):
export_module = semantic_segmentation.SegmentationModule(
params=params, batch_size=batch_size, input_image_size=input_image_size)
else:
raise ValueError('Export module not implemented for {} task.'.format(
type(params.task)))
if not export_module:
if isinstance(params.task,
configs.image_classification.ImageClassificationTask):
export_module = image_classification.ClassificationModule(
params=params,
batch_size=batch_size,
input_image_size=input_image_size)
elif isinstance(params.task, configs.retinanet.RetinaNetTask) or isinstance(
params.task, configs.maskrcnn.MaskRCNNTask):
export_module = detection.DetectionModule(
params=params,
batch_size=batch_size,
input_image_size=input_image_size)
elif isinstance(params.task,
configs.semantic_segmentation.SemanticSegmentationTask):
export_module = semantic_segmentation.SegmentationModule(
params=params,
batch_size=batch_size,
input_image_size=input_image_size)
else:
raise ValueError('Export module not implemented for {} task.'.format(
type(params.task)))
model = export_module.build_model()
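For reference, a minimal usage sketch of the updated function follows. It assumes the function lives in official/vision/beta/serving/export_saved_model_lib (the file name is not shown in this diff), that 'retinanet_resnetfpn_coco' is a registered experiment name, and that 'image_tensor' is an accepted input_type; the checkpoint and export paths are placeholders. The second call illustrates the new export_module argument, constructing detection.DetectionModule with the same arguments the dispatch code above passes.

# Hedged usage sketch; the module path, experiment name, input_type value and
# file paths below are assumptions, not part of this commit.
from official.core import exp_factory
from official.vision.beta.serving import detection
from official.vision.beta.serving import export_saved_model_lib

params = exp_factory.get_exp_config('retinanet_resnetfpn_coco')

# Default behaviour: with export_module=None, the task type in `params`
# selects one of the built-in serving modules shown above.
export_saved_model_lib.export_inference_graph(
    input_type='image_tensor',
    batch_size=1,
    input_image_size=[640, 640],
    params=params,
    checkpoint_path='/tmp/retinanet/ckpt',
    export_dir='/tmp/retinanet/export')

# New in this change: a pre-built export module can be supplied directly,
# bypassing the params-based dispatch.
module = detection.DetectionModule(
    params=params, batch_size=1, input_image_size=[640, 640])
export_saved_model_lib.export_inference_graph(
    input_type='image_tensor',
    batch_size=1,
    input_image_size=[640, 640],
    params=params,
    checkpoint_path='/tmp/retinanet/ckpt',
    export_dir='/tmp/retinanet/export',
    export_module=module)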