"vscode:/vscode.git/clone" did not exist on "307489870512e9af61718cf218a653f876449e14"
Commit 744f765e authored by Abdullah Rashwan's avatar Abdullah Rashwan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 404833082
parent 14432672
...@@ -253,3 +253,21 @@ class Parser(parser.Parser): ...@@ -253,3 +253,21 @@ class Parser(parser.Parser):
image = tf.image.convert_image_dtype(image, self._dtype) image = tf.image.convert_image_dtype(image, self._dtype)
return image return image
@classmethod
def inference_fn(cls,
image: tf.Tensor,
input_image_size: List[int],
num_channels: int = 3) -> tf.Tensor:
"""Builds image model inputs for serving."""
image = tf.cast(image, dtype=tf.float32)
image = preprocess_ops.center_crop_image(image)
image = tf.image.resize(
image, input_image_size, method=tf.image.ResizeMethod.BILINEAR)
# Normalizes image with mean and std pixel values.
image = preprocess_ops.normalize_image(
image, offset=MEAN_RGB, scale=STDDEV_RGB)
image.set_shape(input_image_size + [num_channels])
return image
...@@ -67,3 +67,15 @@ class Parser(object): ...@@ -67,3 +67,15 @@ class Parser(object):
return self._parse_eval_data(decoded_tensors) return self._parse_eval_data(decoded_tensors)
return parse return parse
@classmethod
def inference_fn(cls, inputs):
"""Parses inputs for predictions.
Args:
inputs: A Tensor, or dictionary of Tensors.
Returns:
processed_inputs: An input tensor to the model.
"""
pass
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base class for model export."""
from typing import Dict, Optional, Text, Callable, Any, Union
import tensorflow as tf
from official.core import export_base
class ExportModule(export_base.ExportModule):
"""Base Export Module."""
def __init__(self,
params,
model: tf.keras.Model,
input_signature: Union[tf.TensorSpec, Dict[str, tf.TensorSpec]],
preprocessor: Optional[Callable[..., Any]] = None,
inference_step: Optional[Callable[..., Any]] = None,
postprocessor: Optional[Callable[..., Any]] = None):
"""Initializes a module for export.
Args:
params: A dataclass for parameters to the module.
model: A tf.keras.Model instance to be exported.
input_signature: tf.TensorSpec, e.g.
tf.TensorSpec(shape=[None, 224, 224, 3], dtype=tf.uint8)
preprocessor: An optional callable function to preprocess the inputs.
inference_step: An optional callable function to forward-pass the model.
postprocessor: An optional callable function to postprocess the model
outputs.
"""
super().__init__(params, model=model, inference_step=inference_step)
self.preprocessor = preprocessor
self.postprocessor = postprocessor
self.input_signature = input_signature
@tf.function
def serve(self, inputs):
x = self.preprocessor(inputs=inputs) if self.preprocessor else inputs
x = self.inference_step(x)
x = self.postprocessor(x) if self.postprocessor else x
return x
def get_inference_signatures(self, function_keys: Dict[Text, Text]):
"""Gets defined function signatures.
Args:
function_keys: A dictionary with keys as the function to create signature
for and values as the signature keys when returns.
Returns:
A dictionary with key as signature key and value as concrete functions
that can be used for tf.saved_model.save.
"""
signatures = {}
for _, def_name in function_keys.items():
signatures[def_name] = self.serve.get_concrete_function(
self.input_signature)
return signatures
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.core.export_base_v2."""
import os
import tensorflow as tf
from official.core import export_base
from official.vision.beta.serving import export_base_v2
class TestModel(tf.keras.Model):
def __init__(self):
super().__init__()
self._dense = tf.keras.layers.Dense(2)
def call(self, inputs):
return {'outputs': self._dense(inputs)}
class ExportBaseTest(tf.test.TestCase):
def test_preprocessor(self):
tmp_dir = self.get_temp_dir()
model = TestModel()
inputs = tf.ones([2, 4], tf.float32)
preprocess_fn = lambda inputs: 2 * inputs
module = export_base_v2.ExportModule(
params=None,
input_signature=tf.TensorSpec(shape=[2, 4]),
model=model,
preprocessor=preprocess_fn)
expected_output = model(preprocess_fn(inputs))
ckpt_path = tf.train.Checkpoint(model=model).save(
os.path.join(tmp_dir, 'ckpt'))
export_dir = export_base.export(
module, ['serving_default'],
export_savedmodel_dir=tmp_dir,
checkpoint_path=ckpt_path,
timestamped=False)
imported = tf.saved_model.load(export_dir)
output = imported.signatures['serving_default'](inputs)
print('output', output)
self.assertAllClose(
output['outputs'].numpy(), expected_output['outputs'].numpy())
def test_postprocessor(self):
tmp_dir = self.get_temp_dir()
model = TestModel()
inputs = tf.ones([2, 4], tf.float32)
postprocess_fn = lambda logits: {'outputs': 2 * logits['outputs']}
module = export_base_v2.ExportModule(
params=None,
model=model,
input_signature=tf.TensorSpec(shape=[2, 4]),
postprocessor=postprocess_fn)
expected_output = postprocess_fn(model(inputs))
ckpt_path = tf.train.Checkpoint(model=model).save(
os.path.join(tmp_dir, 'ckpt'))
export_dir = export_base.export(
module, ['serving_default'],
export_savedmodel_dir=tmp_dir,
checkpoint_path=ckpt_path,
timestamped=False)
imported = tf.saved_model.load(export_dir)
output = imported.signatures['serving_default'](inputs)
self.assertAllClose(
output['outputs'].numpy(), expected_output['outputs'].numpy())
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Factory for vision export modules."""
from typing import List, Optional
import tensorflow as tf
from official.core import config_definitions as cfg
from official.vision.beta import configs
from official.vision.beta.dataloaders import classification_input
from official.vision.beta.modeling import factory
from official.vision.beta.serving import export_base_v2 as export_base
from official.vision.beta.serving import export_utils
def create_classification_export_module(params: cfg.ExperimentConfig,
input_type: str,
batch_size: int,
input_image_size: List[int],
num_channels: int = 3):
"""Creats classification export module."""
input_signature = export_utils.get_image_input_signatures(
input_type, batch_size, input_image_size, num_channels)
input_specs = tf.keras.layers.InputSpec(
shape=[batch_size] + input_image_size + [num_channels])
model = factory.build_classification_model(
input_specs=input_specs,
model_config=params.task.model,
l2_regularizer=None)
def preprocess_fn(inputs):
image_tensor = export_utils.parse_image(
inputs, input_type, input_image_size, num_channels)
def preprocess_image_fn(inputs):
return classification_input.Parser.inference_fn(
inputs, input_image_size, num_channels)
images = tf.map_fn(
preprocess_image_fn, elems=image_tensor,
fn_output_signature=tf.TensorSpec(
shape=input_image_size + [num_channels],
dtype=tf.float32))
return images
def postprocess_fn(logits):
probs = tf.nn.softmax(logits)
return {'logits': logits, 'probs': probs}
export_module = export_base.ExportModule(params,
model=model,
input_signature=input_signature,
preprocessor=preprocess_fn,
postprocessor=postprocess_fn)
return export_module
def get_export_module(params: cfg.ExperimentConfig,
input_type: str,
batch_size: Optional[int],
input_image_size: List[int],
num_channels: int = 3) -> export_base.ExportModule:
"""Factory for export modules."""
if isinstance(params.task,
configs.image_classification.ImageClassificationTask):
export_module = create_classification_export_module(
params, input_type, batch_size, input_image_size, num_channels)
else:
raise ValueError('Export module not implemented for {} task.'.format(
type(params.task)))
return export_module
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test for vision modules."""
import io
import os
from absl.testing import parameterized
import numpy as np
from PIL import Image
import tensorflow as tf
from official.common import registry_imports # pylint: disable=unused-import
from official.core import exp_factory
from official.core import export_base
from official.vision.beta.dataloaders import classification_input
from official.vision.beta.serving import export_module_factory
class ImageClassificationExportTest(tf.test.TestCase, parameterized.TestCase):
def _get_classification_module(self, input_type, input_image_size):
params = exp_factory.get_exp_config('resnet_imagenet')
params.task.model.backbone.resnet.model_id = 18
module = export_module_factory.create_classification_export_module(
params, input_type, batch_size=1, input_image_size=input_image_size)
return module
def _get_dummy_input(self, input_type):
"""Get dummy input for the given input type."""
if input_type == 'image_tensor':
return tf.zeros((1, 32, 32, 3), dtype=np.uint8)
elif input_type == 'image_bytes':
image = Image.fromarray(np.zeros((32, 32, 3), dtype=np.uint8))
byte_io = io.BytesIO()
image.save(byte_io, 'PNG')
return [byte_io.getvalue()]
elif input_type == 'tf_example':
image_tensor = tf.zeros((32, 32, 3), dtype=tf.uint8)
encoded_jpeg = tf.image.encode_jpeg(tf.constant(image_tensor)).numpy()
example = tf.train.Example(
features=tf.train.Features(
feature={
'image/encoded':
tf.train.Feature(
bytes_list=tf.train.BytesList(value=[encoded_jpeg])),
})).SerializeToString()
return [example]
@parameterized.parameters(
{'input_type': 'image_tensor'},
{'input_type': 'image_bytes'},
{'input_type': 'tf_example'},
)
def test_export(self, input_type='image_tensor'):
input_image_size = [32, 32]
tmp_dir = self.get_temp_dir()
module = self._get_classification_module(input_type, input_image_size)
# Test that the model restores any attrs that are trackable objects
# (eg: tables, resource variables, keras models/layers, tf.hub modules).
module.model.test_trackable = tf.keras.layers.InputLayer(input_shape=(4,))
ckpt_path = tf.train.Checkpoint(model=module.model).save(
os.path.join(tmp_dir, 'ckpt'))
export_dir = export_base.export(
module, [input_type],
export_savedmodel_dir=tmp_dir,
checkpoint_path=ckpt_path,
timestamped=False)
self.assertTrue(os.path.exists(os.path.join(tmp_dir, 'saved_model.pb')))
self.assertTrue(os.path.exists(
os.path.join(tmp_dir, 'variables', 'variables.index')))
self.assertTrue(os.path.exists(
os.path.join(tmp_dir, 'variables', 'variables.data-00000-of-00001')))
imported = tf.saved_model.load(export_dir)
classification_fn = imported.signatures['serving_default']
images = self._get_dummy_input(input_type)
def preprocess_image_fn(inputs):
return classification_input.Parser.inference_fn(
inputs, input_image_size, num_channels=3)
processed_images = tf.map_fn(
preprocess_image_fn,
elems=tf.zeros([1] + input_image_size + [3], dtype=tf.uint8),
fn_output_signature=tf.TensorSpec(
shape=input_image_size + [3], dtype=tf.float32))
expected_logits = module.model(processed_images, training=False)
expected_prob = tf.nn.softmax(expected_logits)
out = classification_fn(tf.constant(images))
# The imported model should contain any trackable attrs that the original
# model had.
self.assertTrue(hasattr(imported.model, 'test_trackable'))
self.assertAllClose(out['logits'].numpy(), expected_logits.numpy())
self.assertAllClose(out['probs'].numpy(), expected_prob.numpy())
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Vision models export utility function for serving/inference."""
import os
from typing import Optional, List
import tensorflow as tf
from official.core import config_definitions as cfg
from official.core import export_base
from official.core import train_utils
from official.vision.beta.serving import export_module_factory
def export(
input_type: str,
batch_size: Optional[int],
input_image_size: List[int],
params: cfg.ExperimentConfig,
checkpoint_path: str,
export_dir: str,
num_channels: Optional[int] = 3,
export_module: Optional[export_base.ExportModule] = None,
export_checkpoint_subdir: Optional[str] = None,
export_saved_model_subdir: Optional[str] = None,
save_options: Optional[tf.saved_model.SaveOptions] = None):
"""Exports the model specified in the exp config.
Saved model is stored at export_dir/saved_model, checkpoint is saved
at export_dir/checkpoint, and params is saved at export_dir/params.yaml.
Args:
input_type: One of `image_tensor`, `image_bytes`, `tf_example`.
batch_size: 'int', or None.
input_image_size: List or Tuple of height and width.
params: Experiment params.
checkpoint_path: Trained checkpoint path or directory.
export_dir: Export directory path.
num_channels: The number of input image channels.
export_module: Optional export module to be used instead of using params
to create one. If None, the params will be used to create an export
module.
export_checkpoint_subdir: Optional subdirectory under export_dir
to store checkpoint.
export_saved_model_subdir: Optional subdirectory under export_dir
to store saved model.
save_options: `SaveOptions` for `tf.saved_model.save`.
"""
if export_checkpoint_subdir:
output_checkpoint_directory = os.path.join(
export_dir, export_checkpoint_subdir)
else:
output_checkpoint_directory = None
if export_saved_model_subdir:
output_saved_model_directory = os.path.join(
export_dir, export_saved_model_subdir)
else:
output_saved_model_directory = export_dir
export_module = export_module_factory.get_export_module(
params,
input_type=input_type,
batch_size=batch_size,
input_image_size=input_image_size,
num_channels=num_channels)
export_base.export(
export_module,
function_keys=[input_type],
export_savedmodel_dir=output_saved_model_directory,
checkpoint_path=checkpoint_path,
timestamped=False,
save_options=save_options)
if output_checkpoint_directory:
ckpt = tf.train.Checkpoint(model=export_module.model)
ckpt.save(os.path.join(output_checkpoint_directory, 'ckpt'))
train_utils.serialize_config(params, export_dir)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper utils for export library."""
from typing import List, Optional
import tensorflow as tf
# pylint: disable=g-long-lambda
def get_image_input_signatures(input_type: str,
batch_size: Optional[int],
input_image_size: List[int],
num_channels: int = 3):
"""Gets input signatures for an image.
Args:
input_type: A `str`, can be either tf_example, image_bytes, or image_tensor.
batch_size: `int` for batch size or None.
input_image_size: List[int] for the height and width of the input image.
num_channels: `int` for number of channels in the input image.
Returns:
tf.TensorSpec of the input tensor.
"""
if input_type == 'image_tensor':
input_signature = tf.TensorSpec(
shape=[batch_size] + [None] * len(input_image_size) + [num_channels],
dtype=tf.uint8)
elif input_type in ['image_bytes', 'serve_examples', 'tf_example']:
input_signature = tf.TensorSpec(shape=[batch_size], dtype=tf.string)
elif input_type == 'tflite':
input_signature = tf.TensorSpec(
shape=[1] + input_image_size + [num_channels], dtype=tf.float32)
else:
raise ValueError('Unrecognized `input_type`')
return input_signature
def decode_image(encoded_image_bytes: str,
input_image_size: List[int],
num_channels: int = 3,) -> tf.Tensor:
"""Decodes an image bytes to an image tensor.
Use `tf.image.decode_image` to decode an image if input is expected to be 2D
image; otherwise use `tf.io.decode_raw` to convert the raw bytes to tensor
and reshape it to desire shape.
Args:
encoded_image_bytes: An encoded image string to be decoded.
input_image_size: List[int] for the desired input size. This will be used to
infer whether the image is 2d or 3d.
num_channels: `int` for number of image channels.
Returns:
A decoded image tensor.
"""
if len(input_image_size) == 2:
# Decode an image if 2D input is expected.
image_tensor = tf.image.decode_image(
encoded_image_bytes, channels=num_channels)
else:
# Convert raw bytes into a tensor and reshape it, if not 2D input.
image_tensor = tf.io.decode_raw(encoded_image_bytes, out_type=tf.uint8)
image_tensor.set_shape([None] * len(input_image_size) + [num_channels])
return image_tensor
def decode_image_tf_example(tf_example_string_tensor: tf.train.Example,
input_image_size: List[int],
num_channels: int = 3,
encoded_key: str = 'image/encoded'
) -> tf.Tensor:
"""Decodes a TF Example to an image tensor."""
keys_to_features = {
encoded_key: tf.io.FixedLenFeature((), tf.string, default_value=''),
}
parsed_tensors = tf.io.parse_single_example(
serialized=tf_example_string_tensor, features=keys_to_features)
image_tensor = decode_image(
parsed_tensors[encoded_key],
input_image_size=input_image_size,
num_channels=num_channels)
return image_tensor
def parse_image(
inputs, input_type: str, input_image_size: List[int], num_channels: int):
"""Parses image."""
if input_type in ['tf_example', 'serve_examples']:
decode_image_tf_example_fn = (
lambda x: decode_image_tf_example(x, input_image_size, num_channels))
image_tensor = tf.map_fn(
decode_image_tf_example_fn,
elems=inputs,
fn_output_signature=tf.TensorSpec(
shape=[None] * len(input_image_size) + [num_channels],
dtype=tf.uint8),
)
elif input_type == 'image_bytes':
decode_image_fn = lambda x: decode_image(x, input_image_size, num_channels)
image_tensor = tf.map_fn(
decode_image_fn, elems=inputs,
fn_output_signature=tf.TensorSpec(
shape=[None] * len(input_image_size) + [num_channels],
dtype=tf.uint8),)
else:
image_tensor = inputs
return image_tensor
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment