Commit eb6e0ac4 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 447598773
parent d801c959
......@@ -14,11 +14,12 @@
"""Factory for YOLO export modules."""
from typing import List, Optional
from typing import Any, Callable, Dict, List, Optional, Text, Union
import tensorflow as tf
from official.core import config_definitions as cfg
from official.core import export_base
from official.vision import configs
from official.vision.beta.projects.yolo.configs.yolo import YoloTask
from official.vision.beta.projects.yolo.modeling import factory as yolo_factory
......@@ -27,16 +28,84 @@ from official.vision.beta.projects.yolo.modeling.decoders import yolo_decoder #
from official.vision.beta.projects.yolo.serving import model_fn as yolo_model_fn
from official.vision.dataloaders import classification_input
from official.vision.modeling import factory
from official.vision.serving import export_base_v2 as export_base
from official.vision.serving import export_utils
class ExportModule(export_base.ExportModule):
"""Base Export Module."""
def __init__(self,
params: cfg.ExperimentConfig,
model: tf.keras.Model,
input_signature: Union[tf.TensorSpec, Dict[str, tf.TensorSpec]],
preprocessor: Optional[Callable[..., Any]] = None,
inference_step: Optional[Callable[..., Any]] = None,
postprocessor: Optional[Callable[..., Any]] = None,
eval_postprocessor: Optional[Callable[..., Any]] = None):
"""Initializes a module for export.
Args:
params: A dataclass for parameters to the module.
model: A tf.keras.Model instance to be exported.
input_signature: tf.TensorSpec, e.g.
tf.TensorSpec(shape=[None, 224, 224, 3], dtype=tf.uint8)
preprocessor: An optional callable to preprocess the inputs.
inference_step: An optional callable to forward-pass the model.
postprocessor: An optional callable to postprocess the model outputs.
eval_postprocessor: An optional callable to postprocess model outputs
used for model evaluation.
"""
super().__init__(
params,
model=model,
preprocessor=preprocessor,
inference_step=inference_step,
postprocessor=postprocessor)
self.input_signature = input_signature
@tf.function
def serve(self, inputs: Any) -> Any:
x = self.preprocessor(inputs=inputs) if self.preprocessor else inputs
x = self.inference_step(x)
x = self.postprocessor(x) if self.postprocessor else x
return x
@tf.function
def serve_eval(self, inputs: Any) -> Any:
x = self.preprocessor(inputs=inputs) if self.preprocessor else inputs
x = self.inference_step(x)
x = self.eval_postprocessor(x) if self.eval_postprocessor else x
return x
def get_inference_signatures(
self, function_keys: Dict[Text, Text]) -> dict[Text, Any]:
"""Gets defined function signatures.
Args:
function_keys: A dictionary with keys as the function to create signature
for and values as the signature keys when returns.
Returns:
A dictionary with key as signature key and value as concrete functions
that can be used for tf.saved_model.save.
"""
signatures = {}
for _, def_name in function_keys.items():
if 'eval' in def_name and self.eval_postprocessor:
signatures[def_name] = self.serve_eval.get_concrete_function(
self.input_signature)
else:
signatures[def_name] = self.serve.get_concrete_function(
self.input_signature)
return signatures
def create_classification_export_module(
params: cfg.ExperimentConfig,
input_type: str,
batch_size: int,
input_image_size: List[int],
num_channels: int = 3) -> export_base.ExportModule:
num_channels: int = 3) -> ExportModule:
"""Creates classification export module."""
input_signature = export_utils.get_image_input_signatures(
input_type, batch_size, input_image_size, num_channels)
......@@ -71,7 +140,7 @@ def create_classification_export_module(
probs = tf.nn.softmax(logits)
return {'logits': logits, 'probs': probs}
export_module = export_base.ExportModule(
export_module = ExportModule(
params,
model=model,
input_signature=input_signature,
......@@ -85,7 +154,7 @@ def create_yolo_export_module(
input_type: str,
batch_size: int,
input_image_size: List[int],
num_channels: int = 3) -> export_base.ExportModule:
num_channels: int = 3) -> ExportModule:
"""Creates YOLO export module."""
input_signature = export_utils.get_image_input_signatures(
input_type, batch_size, input_image_size, num_channels)
......@@ -144,7 +213,7 @@ def create_yolo_export_module(
return final_outputs
export_module = export_base.ExportModule(
export_module = ExportModule(
params,
model=model,
input_signature=input_signature,
......@@ -158,7 +227,7 @@ def get_export_module(params: cfg.ExperimentConfig,
input_type: str,
batch_size: Optional[int],
input_image_size: List[int],
num_channels: int = 3) -> export_base.ExportModule:
num_channels: int = 3) -> ExportModule:
"""Factory for export modules."""
if isinstance(params.task,
configs.image_classification.ImageClassificationTask):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment