Commit e876cb38 authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Control model's input size by parameters instead of hardcoding it in export library.

PiperOrigin-RevId: 359409792
parent ef6d5a73
@@ -16,33 +16,30 @@
 """Base class for model export."""
 import abc
-import tensorflow as tf
+from typing import Optional, Sequence, Mapping
-
-
-def _decode_image(encoded_image_bytes):
-  image_tensor = tf.image.decode_image(encoded_image_bytes, channels=3)
-  image_tensor.set_shape((None, None, 3))
-  return image_tensor
+
+import tensorflow as tf
-
-
-def _decode_tf_example(tf_example_string_tensor):
-  keys_to_features = {'image/encoded': tf.io.FixedLenFeature((), tf.string)}
-  parsed_tensors = tf.io.parse_single_example(
-      serialized=tf_example_string_tensor, features=keys_to_features)
-  image_tensor = _decode_image(parsed_tensors['image/encoded'])
-  return image_tensor
+from official.modeling.hyperparams import config_definitions as cfg
 class ExportModule(tf.Module, metaclass=abc.ABCMeta):
   """Base Export Module."""

-  def __init__(self, params, batch_size, input_image_size, model=None):
+  def __init__(self,
+               params: cfg.ExperimentConfig,
+               batch_size: int,
+               input_image_size: Sequence[int],
+               num_channels: int = 3,
+               model: Optional[tf.keras.Model] = None):
     """Initializes a module for export.

     Args:
       params: Experiment params.
-      batch_size: Int or None.
-      input_image_size: List or Tuple of height, width of the input image.
+      batch_size: The batch size of the model input. Can be `int` or None.
+      input_image_size: List or Tuple of size of the input image. For 2D image,
+        it is [height, width].
+      num_channels: The number of the image channels.
       model: A tf.keras.Model instance to be exported.
     """
@@ -50,48 +47,98 @@ class ExportModule(tf.Module, metaclass=abc.ABCMeta):
     self._params = params
     self._batch_size = batch_size
     self._input_image_size = input_image_size
+    self._num_channels = num_channels
     self._model = model
def _decode_image(self, encoded_image_bytes: str) -> tf.Tensor:
"""Decodes an image bytes to an image tensor.
Use `tf.image.decode_image` to decode an image if input is expected to be 2D
image; otherwise use `tf.io.decode_raw` to convert the raw bytes to tensor
and reshape it to desire shape.
Args:
encoded_image_bytes: An encoded image string to be decoded.
Returns:
A decoded image tensor.
"""
if len(self._input_image_size) == 2:
# Decode an image if 2D input is expected.
image_tensor = tf.image.decode_image(
encoded_image_bytes, channels=self._num_channels)
image_tensor.set_shape((None, None, self._num_channels))
else:
# Convert raw bytes into a tensor and reshape it, if not 2D input.
image_tensor = tf.io.decode_raw(encoded_image_bytes, out_type=tf.uint8)
image_tensor = tf.reshape(image_tensor,
self._input_image_size + [self._num_channels])
return image_tensor
def _decode_tf_example(
self, tf_example_string_tensor: tf.train.Example) -> tf.Tensor:
"""Decodes a TF Example to an image tensor.
Args:
tf_example_string_tensor: A tf.train.Example of encoded image and other
information.
Returns:
A decoded image tensor.
"""
keys_to_features = {'image/encoded': tf.io.FixedLenFeature((), tf.string)}
parsed_tensors = tf.io.parse_single_example(
serialized=tf_example_string_tensor, features=keys_to_features)
image_tensor = self._decode_image(parsed_tensors['image/encoded'])
return image_tensor
   @abc.abstractmethod
   def build_model(self, **kwargs):
     """Builds model and sets self._model."""

   @abc.abstractmethod
-  def _run_inference_on_image_tensors(self, images):
+  def _run_inference_on_image_tensors(
+      self, images: tf.Tensor) -> Mapping[str, tf.Tensor]:
     """Runs inference on images."""

   @tf.function
-  def inference_from_image_tensors(self, input_tensor):
+  def inference_from_image_tensors(
+      self, input_tensor: tf.Tensor) -> Mapping[str, tf.Tensor]:
     return self._run_inference_on_image_tensors(input_tensor)

   @tf.function
-  def inference_from_image_bytes(self, input_tensor):
+  def inference_from_image_bytes(self, input_tensor: str):
     with tf.device('cpu:0'):
       images = tf.nest.map_structure(
           tf.identity,
           tf.map_fn(
-              _decode_image,
+              self._decode_image,
               elems=input_tensor,
-              fn_output_signature=tf.TensorSpec(
-                  shape=[None, None, 3], dtype=tf.uint8),
+              fn_output_signature=tf.TensorSpec(
+                  shape=[None] * len(self._input_image_size) +
+                  [self._num_channels],
+                  dtype=tf.uint8),
               parallel_iterations=32))
       images = tf.stack(images)
       return self._run_inference_on_image_tensors(images)
   @tf.function
-  def inference_from_tf_example(self, input_tensor):
+  def inference_from_tf_example(
+      self, input_tensor: tf.train.Example) -> Mapping[str, tf.Tensor]:
     with tf.device('cpu:0'):
       images = tf.nest.map_structure(
           tf.identity,
           tf.map_fn(
-              _decode_tf_example,
+              self._decode_tf_example,
               elems=input_tensor,
               # Height/width of the shape of input images is unspecified (None)
               # at the time of decoding the example, but the shape will
               # be adjusted to conform to the input layer of the model,
               # by _run_inference_on_image_tensors() below.
-              fn_output_signature=tf.TensorSpec(
-                  shape=[None, None, 3], dtype=tf.uint8),
+              fn_output_signature=tf.TensorSpec(
+                  shape=[None] * len(self._input_image_size) +
+                  [self._num_channels],
+                  dtype=tf.uint8),
               dtype=tf.uint8,
               parallel_iterations=32))
       images = tf.stack(images)
...
@@ -37,6 +37,7 @@ def export_inference_graph(
     params: cfg.ExperimentConfig,
     checkpoint_path: str,
     export_dir: str,
+    num_channels: Optional[int] = 3,
     export_module: Optional[export_base.ExportModule] = None,
     export_checkpoint_subdir: Optional[str] = None,
     export_saved_model_subdir: Optional[str] = None):
@@ -52,6 +53,7 @@ def export_inference_graph(
     params: Experiment params.
     checkpoint_path: Trained checkpoint path or directory.
     export_dir: Export directory path.
+    num_channels: The number of input image channels.
     export_module: Optional export module to be used instead of using params
       to create one. If None, the params will be used to create an export
       module.
@@ -79,19 +81,22 @@ def export_inference_graph(
     export_module = image_classification.ClassificationModule(
         params=params,
         batch_size=batch_size,
-        input_image_size=input_image_size)
+        input_image_size=input_image_size,
+        num_channels=num_channels)
   elif isinstance(params.task, configs.retinanet.RetinaNetTask) or isinstance(
       params.task, configs.maskrcnn.MaskRCNNTask):
     export_module = detection.DetectionModule(
         params=params,
         batch_size=batch_size,
-        input_image_size=input_image_size)
+        input_image_size=input_image_size,
+        num_channels=num_channels)
   elif isinstance(params.task,
                   configs.semantic_segmentation.SemanticSegmentationTask):
     export_module = semantic_segmentation.SegmentationModule(
         params=params,
         batch_size=batch_size,
-        input_image_size=input_image_size)
+        input_image_size=input_image_size,
+        num_channels=num_channels)
   else:
     raise ValueError('Export module not implemented for {} task.'.format(
         type(params.task)))
@@ -107,7 +112,7 @@ def export_inference_graph(
   if input_type == 'image_tensor':
     input_signature = tf.TensorSpec(
-        shape=[batch_size, None, None, 3],
+        shape=[batch_size] + [None] * len(input_image_size) + [num_channels],
         dtype=tf.uint8)
     signatures = {
         'serving_default':
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment