"test/git@developer.sourcefind.cn:change/sglang.git" did not exist on "2d831c6ef902ee3fcecbf3be6a5f3e43b645bad1"
Commit 14b7ac52 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 360539689
parent 4d5e9122
...@@ -16,17 +16,26 @@ ...@@ -16,17 +16,26 @@
import abc import abc
import functools import functools
from typing import Any, Dict, List, Optional, Text, Union from typing import Any, Callable, Dict, Mapping, List, Optional, Text, Union
import tensorflow as tf import tensorflow as tf
from tensorflow.python.saved_model.model_utils import export_utils from tensorflow.python.saved_model.model_utils import export_utils
# TODO(hongkuny): add unit tests.
class ExportModule(tf.Module, metaclass=abc.ABCMeta): class ExportModule(tf.Module, metaclass=abc.ABCMeta):
"""Base Export Module.""" """Base Export Module."""
def __init__(self, params, model: tf.keras.Model, inference_step=None): def __init__(self,
params,
model: Union[tf.Module, tf.keras.Model],
inference_step: Optional[Callable[..., Any]] = None):
"""Instantiates an ExportModel.
Args:
params: A dataclass for parameters to the module.
model: A model instance which contains weights and forward computation.
inference_step: An optional callable to define how the model is called.
"""
super().__init__(name=None) super().__init__(name=None)
self.model = model self.model = model
self.params = params self.params = params
...@@ -38,7 +47,7 @@ class ExportModule(tf.Module, metaclass=abc.ABCMeta): ...@@ -38,7 +47,7 @@ class ExportModule(tf.Module, metaclass=abc.ABCMeta):
self.model.__call__, training=False) self.model.__call__, training=False)
@abc.abstractmethod @abc.abstractmethod
def serve(self) -> Dict[str, tf.Tensor]: def serve(self) -> Mapping[Text, tf.Tensor]:
"""The bare inference function which should run on all devices. """The bare inference function which should run on all devices.
Expecting tensors are passed in through keyword arguments. Returns a Expecting tensors are passed in through keyword arguments. Returns a
...@@ -47,7 +56,7 @@ class ExportModule(tf.Module, metaclass=abc.ABCMeta): ...@@ -47,7 +56,7 @@ class ExportModule(tf.Module, metaclass=abc.ABCMeta):
@abc.abstractmethod @abc.abstractmethod
def get_inference_signatures( def get_inference_signatures(
self, function_keys: Dict[Text, Text]) -> Dict[str, Any]: self, function_keys: Dict[Text, Text]) -> Mapping[Text, Any]:
"""Get defined function signatures.""" """Get defined function signatures."""
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.core.export_base."""
import os
from typing import Any, Dict, Mapping, Text
import tensorflow as tf
from official.core import export_base
class TestModule(export_base.ExportModule):
@tf.function
def serve(self, inputs: tf.Tensor) -> Mapping[Text, tf.Tensor]:
return {'outputs': self.inference_step(inputs)}
def get_inference_signatures(
self, function_keys: Dict[Text, Text]) -> Mapping[Text, Any]:
input_signature = tf.TensorSpec(shape=[None, None], dtype=tf.float32)
return {'foo': self.serve.get_concrete_function(input_signature)}
class ExportBaseTest(tf.test.TestCase):
def test_export_module(self):
tmp_dir = self.get_temp_dir()
model = tf.keras.layers.Dense(2)
inputs = tf.ones([2, 4], tf.float32)
expected_output = model(inputs, training=False)
module = TestModule(params=None, model=model)
ckpt_path = tf.train.Checkpoint(model=model).save(
os.path.join(tmp_dir, 'ckpt'))
export_dir = export_base.export(
module, ['foo'],
export_savedmodel_dir=tmp_dir,
checkpoint_path=ckpt_path,
timestamped=True)
self.assertTrue(os.path.exists(os.path.join(export_dir, 'saved_model.pb')))
self.assertTrue(
os.path.exists(
os.path.join(export_dir, 'variables', 'variables.index')))
self.assertTrue(
os.path.exists(
os.path.join(export_dir, 'variables',
'variables.data-00000-of-00001')))
imported = tf.saved_model.load(export_dir)
output = imported.signatures['foo'](inputs)
self.assertAllClose(output['outputs'].numpy(), expected_output.numpy())
def test_custom_inference_step(self):
tmp_dir = self.get_temp_dir()
model = tf.keras.layers.Dense(2)
inputs = tf.ones([2, 4], tf.float32)
def _inference_step(inputs, model):
return tf.nn.softmax(model(inputs, training=False))
module = TestModule(
params=None, model=model, inference_step=_inference_step)
expected_output = _inference_step(inputs, model)
ckpt_path = tf.train.Checkpoint(model=model).save(
os.path.join(tmp_dir, 'ckpt'))
export_dir = export_base.export(
module, ['foo'],
export_savedmodel_dir=tmp_dir,
checkpoint_path=ckpt_path,
timestamped=False)
imported = tf.saved_model.load(export_dir)
output = imported.signatures['foo'](inputs)
self.assertAllClose(output['outputs'].numpy(), expected_output.numpy())
if __name__ == '__main__':
tf.test.main()
...@@ -31,32 +31,30 @@ STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255) ...@@ -31,32 +31,30 @@ STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)
class DetectionModule(export_base.ExportModule): class DetectionModule(export_base.ExportModule):
"""Detection Module.""" """Detection Module."""
def build_model(self): def _build_model(self):
if self._batch_size is None: if self._batch_size is None:
ValueError("batch_size can't be None for detection models") ValueError("batch_size can't be None for detection models")
if not self._params.task.model.detection_generator.use_batched_nms: if not self.params.task.model.detection_generator.use_batched_nms:
ValueError('Only batched_nms is supported.') ValueError('Only batched_nms is supported.')
input_specs = tf.keras.layers.InputSpec(shape=[self._batch_size] + input_specs = tf.keras.layers.InputSpec(shape=[self._batch_size] +
self._input_image_size + [3]) self._input_image_size + [3])
if isinstance(self._params.task.model, configs.maskrcnn.MaskRCNN): if isinstance(self.params.task.model, configs.maskrcnn.MaskRCNN):
self._model = factory.build_maskrcnn( model = factory.build_maskrcnn(
input_specs=input_specs, input_specs=input_specs, model_config=self.params.task.model)
model_config=self._params.task.model) elif isinstance(self.params.task.model, configs.retinanet.RetinaNet):
elif isinstance(self._params.task.model, configs.retinanet.RetinaNet): model = factory.build_retinanet(
self._model = factory.build_retinanet( input_specs=input_specs, model_config=self.params.task.model)
input_specs=input_specs,
model_config=self._params.task.model)
else: else:
raise ValueError('Detection module not implemented for {} model.'.format( raise ValueError('Detection module not implemented for {} model.'.format(
type(self._params.task.model))) type(self.params.task.model)))
return self._model return model
def _build_inputs(self, image): def _build_inputs(self, image):
"""Builds detection model inputs for serving.""" """Builds detection model inputs for serving."""
model_params = self._params.task.model model_params = self.params.task.model
# Normalizes image with mean and std pixel values. # Normalizes image with mean and std pixel values.
image = preprocess_ops.normalize_image(image, image = preprocess_ops.normalize_image(image,
offset=MEAN_RGB, offset=MEAN_RGB,
...@@ -81,7 +79,7 @@ class DetectionModule(export_base.ExportModule): ...@@ -81,7 +79,7 @@ class DetectionModule(export_base.ExportModule):
return image, anchor_boxes, image_info return image, anchor_boxes, image_info
def _run_inference_on_image_tensors(self, images: tf.Tensor): def serve(self, images: tf.Tensor):
"""Cast image to float and run inference. """Cast image to float and run inference.
Args: Args:
...@@ -89,7 +87,7 @@ class DetectionModule(export_base.ExportModule): ...@@ -89,7 +87,7 @@ class DetectionModule(export_base.ExportModule):
Returns: Returns:
Tensor holding detection output logits. Tensor holding detection output logits.
""" """
model_params = self._params.task.model model_params = self.params.task.model
with tf.device('cpu:0'): with tf.device('cpu:0'):
images = tf.cast(images, dtype=tf.float32) images = tf.cast(images, dtype=tf.float32)
...@@ -122,7 +120,7 @@ class DetectionModule(export_base.ExportModule): ...@@ -122,7 +120,7 @@ class DetectionModule(export_base.ExportModule):
input_image_shape = image_info[:, 1, :] input_image_shape = image_info[:, 1, :]
detections = self._model.call( detections = self.model.call(
images=images, images=images,
image_shape=input_image_shape, image_shape=input_image_shape,
anchor_boxes=anchor_boxes, anchor_boxes=anchor_boxes,
......
...@@ -38,35 +38,10 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase): ...@@ -38,35 +38,10 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
params, batch_size=1, input_image_size=[640, 640]) params, batch_size=1, input_image_size=[640, 640])
return detection_module return detection_module
def _export_from_module(self, module, input_type, batch_size, save_directory): def _export_from_module(self, module, input_type, save_directory):
if input_type == 'image_tensor': signatures = module.get_inference_signatures(
input_signature = tf.TensorSpec( {input_type: 'serving_default'})
shape=[batch_size, None, None, 3], dtype=tf.uint8) tf.saved_model.save(module, save_directory, signatures=signatures)
signatures = {
'serving_default':
module.inference_from_image_tensors.get_concrete_function(
input_signature)
}
elif input_type == 'image_bytes':
input_signature = tf.TensorSpec(shape=[batch_size], dtype=tf.string)
signatures = {
'serving_default':
module.inference_from_image_bytes.get_concrete_function(
input_signature)
}
elif input_type == 'tf_example':
input_signature = tf.TensorSpec(shape=[batch_size], dtype=tf.string)
signatures = {
'serving_default':
module.inference_from_tf_example.get_concrete_function(
input_signature)
}
else:
raise ValueError('Unrecognized `input_type`')
tf.saved_model.save(module,
save_directory,
signatures=signatures)
def _get_dummy_input(self, input_type, batch_size, image_size): def _get_dummy_input(self, input_type, batch_size, image_size):
"""Get dummy input for the given input type.""" """Get dummy input for the given input type."""
...@@ -107,23 +82,23 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase): ...@@ -107,23 +82,23 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
) )
def test_export(self, input_type, experiment_name, image_size): def test_export(self, input_type, experiment_name, image_size):
tmp_dir = self.get_temp_dir() tmp_dir = self.get_temp_dir()
batch_size = 1
module = self._get_detection_module(experiment_name) module = self._get_detection_module(experiment_name)
model = module.build_model()
self._export_from_module(module, input_type, batch_size, tmp_dir) self._export_from_module(module, input_type, tmp_dir)
self.assertTrue(os.path.exists(os.path.join(tmp_dir, 'saved_model.pb'))) self.assertTrue(os.path.exists(os.path.join(tmp_dir, 'saved_model.pb')))
self.assertTrue(os.path.exists( self.assertTrue(
os.path.join(tmp_dir, 'variables', 'variables.index'))) os.path.exists(os.path.join(tmp_dir, 'variables', 'variables.index')))
self.assertTrue(os.path.exists( self.assertTrue(
os.path.join(tmp_dir, 'variables', 'variables.data-00000-of-00001'))) os.path.exists(
os.path.join(tmp_dir, 'variables',
'variables.data-00000-of-00001')))
imported = tf.saved_model.load(tmp_dir) imported = tf.saved_model.load(tmp_dir)
detection_fn = imported.signatures['serving_default'] detection_fn = imported.signatures['serving_default']
images = self._get_dummy_input(input_type, batch_size, image_size) images = self._get_dummy_input(
input_type, batch_size=1, image_size=image_size)
processed_images, anchor_boxes, image_info = module._build_inputs( processed_images, anchor_boxes, image_info = module._build_inputs(
tf.zeros((224, 224, 3), dtype=tf.uint8)) tf.zeros((224, 224, 3), dtype=tf.uint8))
...@@ -133,7 +108,7 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase): ...@@ -133,7 +108,7 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
for l, l_boxes in anchor_boxes.items(): for l, l_boxes in anchor_boxes.items():
anchor_boxes[l] = tf.expand_dims(l_boxes, 0) anchor_boxes[l] = tf.expand_dims(l_boxes, 0)
expected_outputs = model( expected_outputs = module.model(
images=processed_images, images=processed_images,
image_shape=image_shape, image_shape=image_shape,
anchor_boxes=anchor_boxes, anchor_boxes=anchor_boxes,
...@@ -143,5 +118,6 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase): ...@@ -143,5 +118,6 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
self.assertAllClose(outputs['num_detections'].numpy(), self.assertAllClose(outputs['num_detections'].numpy(),
expected_outputs['num_detections'].numpy()) expected_outputs['num_detections'].numpy())
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
...@@ -16,20 +16,22 @@ ...@@ -16,20 +16,22 @@
"""Base class for model export.""" """Base class for model export."""
import abc import abc
from typing import Optional, Sequence, Mapping from typing import Dict, List, Mapping, Optional, Text
import tensorflow as tf import tensorflow as tf
from official.core import export_base
from official.modeling.hyperparams import config_definitions as cfg from official.modeling.hyperparams import config_definitions as cfg
class ExportModule(tf.Module, metaclass=abc.ABCMeta): class ExportModule(export_base.ExportModule, metaclass=abc.ABCMeta):
"""Base Export Module.""" """Base Export Module."""
def __init__(self, def __init__(self,
params: cfg.ExperimentConfig, params: cfg.ExperimentConfig,
*,
batch_size: int, batch_size: int,
input_image_size: Sequence[int], input_image_size: List[int],
num_channels: int = 3, num_channels: int = 3,
model: Optional[tf.keras.Model] = None): model: Optional[tf.keras.Model] = None):
"""Initializes a module for export. """Initializes a module for export.
...@@ -42,13 +44,13 @@ class ExportModule(tf.Module, metaclass=abc.ABCMeta): ...@@ -42,13 +44,13 @@ class ExportModule(tf.Module, metaclass=abc.ABCMeta):
num_channels: The number of the image channels. num_channels: The number of the image channels.
model: A tf.keras.Model instance to be exported. model: A tf.keras.Model instance to be exported.
""" """
self.params = params
super(ExportModule, self).__init__()
self._params = params
self._batch_size = batch_size self._batch_size = batch_size
self._input_image_size = input_image_size self._input_image_size = input_image_size
self._num_channels = num_channels self._num_channels = num_channels
self._model = model if model is None:
model = self._build_model() # pylint: disable=assignment-from-none
super().__init__(params=params, model=model)
def _decode_image(self, encoded_image_bytes: str) -> tf.Tensor: def _decode_image(self, encoded_image_bytes: str) -> tf.Tensor:
"""Decodes an image bytes to an image tensor. """Decodes an image bytes to an image tensor.
...@@ -92,45 +94,40 @@ class ExportModule(tf.Module, metaclass=abc.ABCMeta): ...@@ -92,45 +94,40 @@ class ExportModule(tf.Module, metaclass=abc.ABCMeta):
image_tensor = self._decode_image(parsed_tensors['image/encoded']) image_tensor = self._decode_image(parsed_tensors['image/encoded'])
return image_tensor return image_tensor
@abc.abstractmethod def _build_model(self, **kwargs):
def build_model(self, **kwargs): """Returns a model built from the params."""
"""Builds model and sets self._model.""" return None
@abc.abstractmethod
def _run_inference_on_image_tensors(
self, images: tf.Tensor) -> Mapping[str, tf.Tensor]:
"""Runs inference on images."""
@tf.function @tf.function
def inference_from_image_tensors( def inference_from_image_tensors(
self, input_tensor: tf.Tensor) -> Mapping[str, tf.Tensor]: self, inputs: tf.Tensor) -> Mapping[str, tf.Tensor]:
return self._run_inference_on_image_tensors(input_tensor) return self.serve(inputs)
@tf.function @tf.function
def inference_from_image_bytes(self, input_tensor: str): def inference_from_image_bytes(self, inputs: tf.Tensor):
with tf.device('cpu:0'): with tf.device('cpu:0'):
images = tf.nest.map_structure( images = tf.nest.map_structure(
tf.identity, tf.identity,
tf.map_fn( tf.map_fn(
self._decode_image, self._decode_image,
elems=input_tensor, elems=inputs,
fn_output_signature=tf.TensorSpec( fn_output_signature=tf.TensorSpec(
shape=[None] * len(self._input_image_size) + shape=[None] * len(self._input_image_size) +
[self._num_channels], [self._num_channels],
dtype=tf.uint8), dtype=tf.uint8),
parallel_iterations=32)) parallel_iterations=32))
images = tf.stack(images) images = tf.stack(images)
return self._run_inference_on_image_tensors(images) return self.serve(images)
@tf.function @tf.function
def inference_from_tf_example( def inference_from_tf_example(self,
self, input_tensor: tf.train.Example) -> Mapping[str, tf.Tensor]: inputs: tf.Tensor) -> Mapping[str, tf.Tensor]:
with tf.device('cpu:0'): with tf.device('cpu:0'):
images = tf.nest.map_structure( images = tf.nest.map_structure(
tf.identity, tf.identity,
tf.map_fn( tf.map_fn(
self._decode_tf_example, self._decode_tf_example,
elems=input_tensor, elems=inputs,
# Height/width of the shape of input images is unspecified (None) # Height/width of the shape of input images is unspecified (None)
# at the time of decoding the example, but the shape will # at the time of decoding the example, but the shape will
# be adjusted to conform to the input layer of the model, # be adjusted to conform to the input layer of the model,
...@@ -142,4 +139,41 @@ class ExportModule(tf.Module, metaclass=abc.ABCMeta): ...@@ -142,4 +139,41 @@ class ExportModule(tf.Module, metaclass=abc.ABCMeta):
dtype=tf.uint8, dtype=tf.uint8,
parallel_iterations=32)) parallel_iterations=32))
images = tf.stack(images) images = tf.stack(images)
return self._run_inference_on_image_tensors(images) return self.serve(images)
def get_inference_signatures(self, function_keys: Dict[Text, Text]):
"""Gets defined function signatures.
Args:
function_keys: A dictionary with keys as the function to create signature
for and values as the signature keys when returns.
Returns:
A dictionary with key as signature key and value as concrete functions
that can be used for tf.saved_model.save.
"""
signatures = {}
for key, def_name in function_keys.items():
if key == 'image_tensor':
input_signature = tf.TensorSpec(
shape=[self._batch_size] + [None] * len(self._input_image_size) +
[self._num_channels],
dtype=tf.uint8)
signatures[
def_name] = self.inference_from_image_tensors.get_concrete_function(
input_signature)
elif key == 'image_bytes':
input_signature = tf.TensorSpec(
shape=[self._batch_size], dtype=tf.string)
signatures[
def_name] = self.inference_from_image_bytes.get_concrete_function(
input_signature)
elif key == 'serve_examples' or key == 'tf_example':
input_signature = tf.TensorSpec(
shape=[self._batch_size], dtype=tf.string)
signatures[
def_name] = self.inference_from_tf_example.get_concrete_function(
input_signature)
else:
raise ValueError('Unrecognized `input_type`')
return signatures
...@@ -16,16 +16,15 @@ ...@@ -16,16 +16,15 @@
r"""Vision models export utility function for serving/inference.""" r"""Vision models export utility function for serving/inference."""
import os import os
from typing import Optional, List from typing import Optional, List
import tensorflow as tf import tensorflow as tf
from official.core import config_definitions as cfg from official.core import config_definitions as cfg
from official.core import export_base
from official.core import train_utils from official.core import train_utils
from official.vision.beta import configs from official.vision.beta import configs
from official.vision.beta.serving import detection from official.vision.beta.serving import detection
from official.vision.beta.serving import export_base
from official.vision.beta.serving import image_classification from official.vision.beta.serving import image_classification
from official.vision.beta.serving import semantic_segmentation from official.vision.beta.serving import semantic_segmentation
...@@ -75,6 +74,7 @@ def export_inference_graph( ...@@ -75,6 +74,7 @@ def export_inference_graph(
else: else:
output_saved_model_directory = export_dir output_saved_model_directory = export_dir
# TODO(arashwan): Offers a direct path to use ExportModule with Task objects.
if not export_module: if not export_module:
if isinstance(params.task, if isinstance(params.task,
configs.image_classification.ImageClassificationTask): configs.image_classification.ImageClassificationTask):
...@@ -101,47 +101,13 @@ def export_inference_graph( ...@@ -101,47 +101,13 @@ def export_inference_graph(
raise ValueError('Export module not implemented for {} task.'.format( raise ValueError('Export module not implemented for {} task.'.format(
type(params.task))) type(params.task)))
model = export_module.build_model() export_base.export(
export_module,
ckpt = tf.train.Checkpoint(model=model) function_keys=[input_type],
export_savedmodel_dir=output_saved_model_directory,
ckpt_dir_or_file = checkpoint_path checkpoint_path=checkpoint_path,
if tf.io.gfile.isdir(ckpt_dir_or_file): timestamped=False)
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
status = ckpt.restore(ckpt_dir_or_file).expect_partial()
if input_type == 'image_tensor':
input_signature = tf.TensorSpec(
shape=[batch_size] + [None] * len(input_image_size) + [num_channels],
dtype=tf.uint8)
signatures = {
'serving_default':
export_module.inference_from_image_tensors.get_concrete_function(
input_signature)
}
elif input_type == 'image_bytes':
input_signature = tf.TensorSpec(shape=[batch_size], dtype=tf.string)
signatures = {
'serving_default':
export_module.inference_from_image_bytes.get_concrete_function(
input_signature)
}
elif input_type == 'tf_example':
input_signature = tf.TensorSpec(shape=[batch_size], dtype=tf.string)
signatures = {
'serving_default':
export_module.inference_from_tf_example.get_concrete_function(
input_signature)
}
else:
raise ValueError('Unrecognized `input_type`')
status.assert_existing_objects_matched()
ckpt = tf.train.Checkpoint(model=export_module.model)
ckpt.save(os.path.join(output_checkpoint_directory, 'ckpt')) ckpt.save(os.path.join(output_checkpoint_directory, 'ckpt'))
tf.saved_model.save(export_module,
output_saved_model_directory,
signatures=signatures)
train_utils.serialize_config(params, export_dir) train_utils.serialize_config(params, export_dir)
...@@ -24,7 +24,7 @@ import tensorflow as tf ...@@ -24,7 +24,7 @@ import tensorflow as tf
from official.common import registry_imports # pylint: disable=unused-import from official.common import registry_imports # pylint: disable=unused-import
from official.core import exp_factory from official.core import exp_factory
from official.modeling import hyperparams from official.modeling import hyperparams
from official.vision.beta.serving import image_classification from official.vision.beta.modeling import factory
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
...@@ -68,10 +68,14 @@ def export_model_to_tfhub(params, ...@@ -68,10 +68,14 @@ def export_model_to_tfhub(params,
checkpoint_path, checkpoint_path,
export_path): export_path):
"""Export an image classification model to TF-Hub.""" """Export an image classification model to TF-Hub."""
export_module = image_classification.ClassificationModule( input_specs = tf.keras.layers.InputSpec(shape=[batch_size] +
params=params, batch_size=batch_size, input_image_size=input_image_size) input_image_size + [3])
model = export_module.build_model(skip_logits_layer=skip_logits_layer) model = factory.build_classification_model(
input_specs=input_specs,
model_config=params.task.model,
l2_regularizer=None,
skip_logits_layer=skip_logits_layer)
checkpoint = tf.train.Checkpoint(model=model) checkpoint = tf.train.Checkpoint(model=model)
checkpoint.restore(checkpoint_path).assert_existing_objects_matched() checkpoint.restore(checkpoint_path).assert_existing_objects_matched()
model.save(export_path, include_optimizer=False, save_format='tf') model.save(export_path, include_optimizer=False, save_format='tf')
......
...@@ -29,17 +29,14 @@ STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255) ...@@ -29,17 +29,14 @@ STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)
class ClassificationModule(export_base.ExportModule): class ClassificationModule(export_base.ExportModule):
"""classification Module.""" """classification Module."""
def build_model(self, skip_logits_layer=False): def _build_model(self):
input_specs = tf.keras.layers.InputSpec( input_specs = tf.keras.layers.InputSpec(
shape=[self._batch_size] + self._input_image_size + [3]) shape=[self._batch_size] + self._input_image_size + [3])
self._model = factory.build_classification_model( return factory.build_classification_model(
input_specs=input_specs, input_specs=input_specs,
model_config=self._params.task.model, model_config=self.params.task.model,
l2_regularizer=None, l2_regularizer=None)
skip_logits_layer=skip_logits_layer)
return self._model
def _build_inputs(self, image): def _build_inputs(self, image):
"""Builds classification model inputs for serving.""" """Builds classification model inputs for serving."""
...@@ -58,7 +55,7 @@ class ClassificationModule(export_base.ExportModule): ...@@ -58,7 +55,7 @@ class ClassificationModule(export_base.ExportModule):
scale=STDDEV_RGB) scale=STDDEV_RGB)
return image return image
def _run_inference_on_image_tensors(self, images): def serve(self, images):
"""Cast image to float and run inference. """Cast image to float and run inference.
Args: Args:
...@@ -79,6 +76,6 @@ class ClassificationModule(export_base.ExportModule): ...@@ -79,6 +76,6 @@ class ClassificationModule(export_base.ExportModule):
) )
) )
logits = self._model(images, training=False) logits = self.inference_step(images)
return dict(outputs=logits) return dict(outputs=logits)
...@@ -38,30 +38,8 @@ class ImageClassificationExportTest(tf.test.TestCase, parameterized.TestCase): ...@@ -38,30 +38,8 @@ class ImageClassificationExportTest(tf.test.TestCase, parameterized.TestCase):
return classification_module return classification_module
def _export_from_module(self, module, input_type, save_directory): def _export_from_module(self, module, input_type, save_directory):
if input_type == 'image_tensor': signatures = module.get_inference_signatures(
input_signature = tf.TensorSpec(shape=[None, 224, 224, 3], dtype=tf.uint8) {input_type: 'serving_default'})
signatures = {
'serving_default':
module.inference_from_image_tensors.get_concrete_function(
input_signature)
}
elif input_type == 'image_bytes':
input_signature = tf.TensorSpec(shape=[None], dtype=tf.string)
signatures = {
'serving_default':
module.inference_from_image_bytes.get_concrete_function(
input_signature)
}
elif input_type == 'tf_example':
input_signature = tf.TensorSpec(shape=[None], dtype=tf.string)
signatures = {
'serving_default':
module.inference_from_tf_example.get_concrete_function(
input_signature)
}
else:
raise ValueError('Unrecognized `input_type`')
tf.saved_model.save(module, tf.saved_model.save(module,
save_directory, save_directory,
signatures=signatures) signatures=signatures)
...@@ -95,9 +73,7 @@ class ImageClassificationExportTest(tf.test.TestCase, parameterized.TestCase): ...@@ -95,9 +73,7 @@ class ImageClassificationExportTest(tf.test.TestCase, parameterized.TestCase):
) )
def test_export(self, input_type='image_tensor'): def test_export(self, input_type='image_tensor'):
tmp_dir = self.get_temp_dir() tmp_dir = self.get_temp_dir()
module = self._get_classification_module() module = self._get_classification_module()
model = module.build_model()
self._export_from_module(module, input_type, tmp_dir) self._export_from_module(module, input_type, tmp_dir)
...@@ -118,7 +94,7 @@ class ImageClassificationExportTest(tf.test.TestCase, parameterized.TestCase): ...@@ -118,7 +94,7 @@ class ImageClassificationExportTest(tf.test.TestCase, parameterized.TestCase):
elems=tf.zeros((1, 224, 224, 3), dtype=tf.uint8), elems=tf.zeros((1, 224, 224, 3), dtype=tf.uint8),
fn_output_signature=tf.TensorSpec( fn_output_signature=tf.TensorSpec(
shape=[224, 224, 3], dtype=tf.float32))) shape=[224, 224, 3], dtype=tf.float32)))
expected_output = model(processed_images, training=False) expected_output = module.model(processed_images, training=False)
out = classification_fn(tf.constant(images)) out = classification_fn(tf.constant(images))
self.assertAllClose(out['outputs'].numpy(), expected_output.numpy()) self.assertAllClose(out['outputs'].numpy(), expected_output.numpy())
......
...@@ -29,17 +29,15 @@ STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255) ...@@ -29,17 +29,15 @@ STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)
class SegmentationModule(export_base.ExportModule): class SegmentationModule(export_base.ExportModule):
"""Segmentation Module.""" """Segmentation Module."""
def build_model(self): def _build_model(self):
input_specs = tf.keras.layers.InputSpec( input_specs = tf.keras.layers.InputSpec(
shape=[self._batch_size] + self._input_image_size + [3]) shape=[self._batch_size] + self._input_image_size + [3])
self._model = factory.build_segmentation_model( return factory.build_segmentation_model(
input_specs=input_specs, input_specs=input_specs,
model_config=self._params.task.model, model_config=self.params.task.model,
l2_regularizer=None) l2_regularizer=None)
return self._model
def _build_inputs(self, image): def _build_inputs(self, image):
"""Builds classification model inputs for serving.""" """Builds classification model inputs for serving."""
...@@ -56,7 +54,7 @@ class SegmentationModule(export_base.ExportModule): ...@@ -56,7 +54,7 @@ class SegmentationModule(export_base.ExportModule):
aug_scale_max=1.0) aug_scale_max=1.0)
return image return image
def _run_inference_on_image_tensors(self, images): def serve(self, images):
"""Cast image to float and run inference. """Cast image to float and run inference.
Args: Args:
...@@ -77,7 +75,7 @@ class SegmentationModule(export_base.ExportModule): ...@@ -77,7 +75,7 @@ class SegmentationModule(export_base.ExportModule):
) )
) )
masks = self._model(images, training=False) masks = self.inference_step(images)
masks = tf.image.resize(masks, self._input_image_size, method='bilinear') masks = tf.image.resize(masks, self._input_image_size, method='bilinear')
return dict(predicted_masks=masks) return dict(predicted_masks=masks)
...@@ -38,33 +38,9 @@ class SemanticSegmentationExportTest(tf.test.TestCase, parameterized.TestCase): ...@@ -38,33 +38,9 @@ class SemanticSegmentationExportTest(tf.test.TestCase, parameterized.TestCase):
return segmentation_module return segmentation_module
def _export_from_module(self, module, input_type, save_directory): def _export_from_module(self, module, input_type, save_directory):
if input_type == 'image_tensor': signatures = module.get_inference_signatures(
input_signature = tf.TensorSpec(shape=[None, 112, 112, 3], dtype=tf.uint8) {input_type: 'serving_default'})
signatures = { tf.saved_model.save(module, save_directory, signatures=signatures)
'serving_default':
module.inference_from_image_tensors.get_concrete_function(
input_signature)
}
elif input_type == 'image_bytes':
input_signature = tf.TensorSpec(shape=[None], dtype=tf.string)
signatures = {
'serving_default':
module.inference_from_image_bytes.get_concrete_function(
input_signature)
}
elif input_type == 'tf_example':
input_signature = tf.TensorSpec(shape=[None], dtype=tf.string)
signatures = {
'serving_default':
module.inference_from_tf_example.get_concrete_function(
input_signature)
}
else:
raise ValueError('Unrecognized `input_type`')
tf.saved_model.save(module,
save_directory,
signatures=signatures)
def _get_dummy_input(self, input_type): def _get_dummy_input(self, input_type):
"""Get dummy input for the given input type.""" """Get dummy input for the given input type."""
...@@ -95,17 +71,17 @@ class SemanticSegmentationExportTest(tf.test.TestCase, parameterized.TestCase): ...@@ -95,17 +71,17 @@ class SemanticSegmentationExportTest(tf.test.TestCase, parameterized.TestCase):
) )
def test_export(self, input_type='image_tensor'): def test_export(self, input_type='image_tensor'):
tmp_dir = self.get_temp_dir() tmp_dir = self.get_temp_dir()
module = self._get_segmentation_module() module = self._get_segmentation_module()
model = module.build_model()
self._export_from_module(module, input_type, tmp_dir) self._export_from_module(module, input_type, tmp_dir)
self.assertTrue(os.path.exists(os.path.join(tmp_dir, 'saved_model.pb'))) self.assertTrue(os.path.exists(os.path.join(tmp_dir, 'saved_model.pb')))
self.assertTrue(os.path.exists( self.assertTrue(
os.path.join(tmp_dir, 'variables', 'variables.index'))) os.path.exists(os.path.join(tmp_dir, 'variables', 'variables.index')))
self.assertTrue(os.path.exists( self.assertTrue(
os.path.join(tmp_dir, 'variables', 'variables.data-00000-of-00001'))) os.path.exists(
os.path.join(tmp_dir, 'variables',
'variables.data-00000-of-00001')))
imported = tf.saved_model.load(tmp_dir) imported = tf.saved_model.load(tmp_dir)
segmentation_fn = imported.signatures['serving_default'] segmentation_fn = imported.signatures['serving_default']
...@@ -119,9 +95,11 @@ class SemanticSegmentationExportTest(tf.test.TestCase, parameterized.TestCase): ...@@ -119,9 +95,11 @@ class SemanticSegmentationExportTest(tf.test.TestCase, parameterized.TestCase):
fn_output_signature=tf.TensorSpec( fn_output_signature=tf.TensorSpec(
shape=[112, 112, 3], dtype=tf.float32))) shape=[112, 112, 3], dtype=tf.float32)))
expected_output = tf.image.resize( expected_output = tf.image.resize(
model(processed_images, training=False), [112, 112], method='bilinear') module.model(processed_images, training=False), [112, 112],
method='bilinear')
out = segmentation_fn(tf.constant(images)) out = segmentation_fn(tf.constant(images))
self.assertAllClose(out['predicted_masks'].numpy(), expected_output.numpy()) self.assertAllClose(out['predicted_masks'].numpy(), expected_output.numpy())
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment