Commit a7ba08aa authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 447598773
parent 930abe21
...@@ -14,11 +14,12 @@ ...@@ -14,11 +14,12 @@
"""Factory for YOLO export modules.""" """Factory for YOLO export modules."""
from typing import List, Optional from typing import Any, Callable, Dict, List, Optional, Text, Union
import tensorflow as tf import tensorflow as tf
from official.core import config_definitions as cfg from official.core import config_definitions as cfg
from official.core import export_base
from official.vision import configs from official.vision import configs
from official.vision.beta.projects.yolo.configs.yolo import YoloTask from official.vision.beta.projects.yolo.configs.yolo import YoloTask
from official.vision.beta.projects.yolo.modeling import factory as yolo_factory from official.vision.beta.projects.yolo.modeling import factory as yolo_factory
...@@ -27,16 +28,84 @@ from official.vision.beta.projects.yolo.modeling.decoders import yolo_decoder # ...@@ -27,16 +28,84 @@ from official.vision.beta.projects.yolo.modeling.decoders import yolo_decoder #
from official.vision.beta.projects.yolo.serving import model_fn as yolo_model_fn from official.vision.beta.projects.yolo.serving import model_fn as yolo_model_fn
from official.vision.dataloaders import classification_input from official.vision.dataloaders import classification_input
from official.vision.modeling import factory from official.vision.modeling import factory
from official.vision.serving import export_base_v2 as export_base
from official.vision.serving import export_utils from official.vision.serving import export_utils
class ExportModule(export_base.ExportModule):
"""Base Export Module."""
def __init__(self,
params: cfg.ExperimentConfig,
model: tf.keras.Model,
input_signature: Union[tf.TensorSpec, Dict[str, tf.TensorSpec]],
preprocessor: Optional[Callable[..., Any]] = None,
inference_step: Optional[Callable[..., Any]] = None,
postprocessor: Optional[Callable[..., Any]] = None,
eval_postprocessor: Optional[Callable[..., Any]] = None):
"""Initializes a module for export.
Args:
params: A dataclass for parameters to the module.
model: A tf.keras.Model instance to be exported.
input_signature: tf.TensorSpec, e.g.
tf.TensorSpec(shape=[None, 224, 224, 3], dtype=tf.uint8)
preprocessor: An optional callable to preprocess the inputs.
inference_step: An optional callable to forward-pass the model.
postprocessor: An optional callable to postprocess the model outputs.
eval_postprocessor: An optional callable to postprocess model outputs
used for model evaluation.
"""
super().__init__(
params,
model=model,
preprocessor=preprocessor,
inference_step=inference_step,
postprocessor=postprocessor)
self.input_signature = input_signature
@tf.function
def serve(self, inputs: Any) -> Any:
x = self.preprocessor(inputs=inputs) if self.preprocessor else inputs
x = self.inference_step(x)
x = self.postprocessor(x) if self.postprocessor else x
return x
@tf.function
def serve_eval(self, inputs: Any) -> Any:
x = self.preprocessor(inputs=inputs) if self.preprocessor else inputs
x = self.inference_step(x)
x = self.eval_postprocessor(x) if self.eval_postprocessor else x
return x
def get_inference_signatures(
self, function_keys: Dict[Text, Text]) -> dict[Text, Any]:
"""Gets defined function signatures.
Args:
function_keys: A dictionary with keys as the function to create signature
for and values as the signature keys when returns.
Returns:
A dictionary with key as signature key and value as concrete functions
that can be used for tf.saved_model.save.
"""
signatures = {}
for _, def_name in function_keys.items():
if 'eval' in def_name and self.eval_postprocessor:
signatures[def_name] = self.serve_eval.get_concrete_function(
self.input_signature)
else:
signatures[def_name] = self.serve.get_concrete_function(
self.input_signature)
return signatures
def create_classification_export_module( def create_classification_export_module(
params: cfg.ExperimentConfig, params: cfg.ExperimentConfig,
input_type: str, input_type: str,
batch_size: int, batch_size: int,
input_image_size: List[int], input_image_size: List[int],
num_channels: int = 3) -> export_base.ExportModule: num_channels: int = 3) -> ExportModule:
"""Creates classification export module.""" """Creates classification export module."""
input_signature = export_utils.get_image_input_signatures( input_signature = export_utils.get_image_input_signatures(
input_type, batch_size, input_image_size, num_channels) input_type, batch_size, input_image_size, num_channels)
...@@ -71,7 +140,7 @@ def create_classification_export_module( ...@@ -71,7 +140,7 @@ def create_classification_export_module(
probs = tf.nn.softmax(logits) probs = tf.nn.softmax(logits)
return {'logits': logits, 'probs': probs} return {'logits': logits, 'probs': probs}
export_module = export_base.ExportModule( export_module = ExportModule(
params, params,
model=model, model=model,
input_signature=input_signature, input_signature=input_signature,
...@@ -85,7 +154,7 @@ def create_yolo_export_module( ...@@ -85,7 +154,7 @@ def create_yolo_export_module(
input_type: str, input_type: str,
batch_size: int, batch_size: int,
input_image_size: List[int], input_image_size: List[int],
num_channels: int = 3) -> export_base.ExportModule: num_channels: int = 3) -> ExportModule:
"""Creates YOLO export module.""" """Creates YOLO export module."""
input_signature = export_utils.get_image_input_signatures( input_signature = export_utils.get_image_input_signatures(
input_type, batch_size, input_image_size, num_channels) input_type, batch_size, input_image_size, num_channels)
...@@ -144,7 +213,7 @@ def create_yolo_export_module( ...@@ -144,7 +213,7 @@ def create_yolo_export_module(
return final_outputs return final_outputs
export_module = export_base.ExportModule( export_module = ExportModule(
params, params,
model=model, model=model,
input_signature=input_signature, input_signature=input_signature,
...@@ -158,7 +227,7 @@ def get_export_module(params: cfg.ExperimentConfig, ...@@ -158,7 +227,7 @@ def get_export_module(params: cfg.ExperimentConfig,
input_type: str, input_type: str,
batch_size: Optional[int], batch_size: Optional[int],
input_image_size: List[int], input_image_size: List[int],
num_channels: int = 3) -> export_base.ExportModule: num_channels: int = 3) -> ExportModule:
"""Factory for export modules.""" """Factory for export modules."""
if isinstance(params.task, if isinstance(params.task,
configs.image_classification.ImageClassificationTask): configs.image_classification.ImageClassificationTask):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment