Commit 4a7bf679 authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 360539689
parent 3b078049
@@ -16,17 +16,26 @@
 import abc
 import functools
-from typing import Any, Dict, List, Optional, Text, Union
+from typing import Any, Callable, Dict, Mapping, List, Optional, Text, Union
 
 import tensorflow as tf
 
 from tensorflow.python.saved_model.model_utils import export_utils
 
 # TODO(hongkuny): add unit tests.
 
 
 class ExportModule(tf.Module, metaclass=abc.ABCMeta):
   """Base Export Module."""
 
-  def __init__(self, params, model: tf.keras.Model, inference_step=None):
+  def __init__(self,
+               params,
+               model: Union[tf.Module, tf.keras.Model],
+               inference_step: Optional[Callable[..., Any]] = None):
+    """Instantiates an ExportModule.
+
+    Args:
+      params: A dataclass for parameters to the module.
+      model: A model instance which contains weights and forward computation.
+      inference_step: An optional callable to define how the model is called.
+    """
     super().__init__(name=None)
     self.model = model
     self.params = params
@@ -38,7 +47,7 @@ class ExportModule(tf.Module, metaclass=abc.ABCMeta):
           self.model.__call__, training=False)
 
   @abc.abstractmethod
-  def serve(self) -> Dict[str, tf.Tensor]:
+  def serve(self) -> Mapping[Text, tf.Tensor]:
     """The bare inference function which should run on all devices.
 
     Tensors are expected to be passed in through keyword arguments. Returns a
@@ -47,7 +56,7 @@ class ExportModule(tf.Module, metaclass=abc.ABCMeta):
 
   @abc.abstractmethod
   def get_inference_signatures(
-      self, function_keys: Dict[Text, Text]) -> Dict[str, Any]:
+      self, function_keys: Dict[Text, Text]) -> Mapping[Text, Any]:
     """Get defined function signatures."""
......
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for official.core.export_base."""
+import os
+from typing import Any, Dict, Mapping, Text
+
+import tensorflow as tf
+
+from official.core import export_base
+
+
+class TestModule(export_base.ExportModule):
+
+  @tf.function
+  def serve(self, inputs: tf.Tensor) -> Mapping[Text, tf.Tensor]:
+    return {'outputs': self.inference_step(inputs)}
+
+  def get_inference_signatures(
+      self, function_keys: Dict[Text, Text]) -> Mapping[Text, Any]:
+    input_signature = tf.TensorSpec(shape=[None, None], dtype=tf.float32)
+    return {'foo': self.serve.get_concrete_function(input_signature)}
+
+
+class ExportBaseTest(tf.test.TestCase):
+
+  def test_export_module(self):
+    tmp_dir = self.get_temp_dir()
+    model = tf.keras.layers.Dense(2)
+    inputs = tf.ones([2, 4], tf.float32)
+    expected_output = model(inputs, training=False)
+    module = TestModule(params=None, model=model)
+    ckpt_path = tf.train.Checkpoint(model=model).save(
+        os.path.join(tmp_dir, 'ckpt'))
+    export_dir = export_base.export(
+        module, ['foo'],
+        export_savedmodel_dir=tmp_dir,
+        checkpoint_path=ckpt_path,
+        timestamped=True)
+    self.assertTrue(os.path.exists(os.path.join(export_dir, 'saved_model.pb')))
+    self.assertTrue(
+        os.path.exists(
+            os.path.join(export_dir, 'variables', 'variables.index')))
+    self.assertTrue(
+        os.path.exists(
+            os.path.join(export_dir, 'variables',
+                         'variables.data-00000-of-00001')))
+
+    imported = tf.saved_model.load(export_dir)
+    output = imported.signatures['foo'](inputs)
+    self.assertAllClose(output['outputs'].numpy(), expected_output.numpy())
+
+  def test_custom_inference_step(self):
+    tmp_dir = self.get_temp_dir()
+    model = tf.keras.layers.Dense(2)
+    inputs = tf.ones([2, 4], tf.float32)
+
+    def _inference_step(inputs, model):
+      return tf.nn.softmax(model(inputs, training=False))
+
+    module = TestModule(
+        params=None, model=model, inference_step=_inference_step)
+    expected_output = _inference_step(inputs, model)
+    ckpt_path = tf.train.Checkpoint(model=model).save(
+        os.path.join(tmp_dir, 'ckpt'))
+    export_dir = export_base.export(
+        module, ['foo'],
+        export_savedmodel_dir=tmp_dir,
+        checkpoint_path=ckpt_path,
+        timestamped=False)
+    imported = tf.saved_model.load(export_dir)
+    output = imported.signatures['foo'](inputs)
+    self.assertAllClose(output['outputs'].numpy(), expected_output.numpy())
+
+
+if __name__ == '__main__':
+  tf.test.main()
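For quick reference, the flow those two tests exercise, condensed into one sketch. Nothing here is new API: TestModule, the 'foo' signature key, and the export arguments all come from the test file above; the Dense layer and the paths are illustrative.

# A sketch distilled from the tests above; paths are illustrative.
model = tf.keras.layers.Dense(2)
module = TestModule(params=None, model=model)
ckpt_path = tf.train.Checkpoint(model=model).save('/tmp/sketch/ckpt')
export_dir = export_base.export(
    module, ['foo'],                   # keys resolved by get_inference_signatures()
    export_savedmodel_dir='/tmp/sketch',
    checkpoint_path=ckpt_path,         # weights are restored before saving
    timestamped=True)                  # a timestamped subdirectory is created
imported = tf.saved_model.load(export_dir)
outputs = imported.signatures['foo'](tf.ones([2, 4], tf.float32))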
@@ -31,32 +31,30 @@ STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)
 
 
 class DetectionModule(export_base.ExportModule):
   """Detection Module."""
 
-  def build_model(self):
+  def _build_model(self):
     if self._batch_size is None:
       raise ValueError("batch_size can't be None for detection models")
-    if not self._params.task.model.detection_generator.use_batched_nms:
+    if not self.params.task.model.detection_generator.use_batched_nms:
       raise ValueError('Only batched_nms is supported.')
 
     input_specs = tf.keras.layers.InputSpec(shape=[self._batch_size] +
                                             self._input_image_size + [3])
 
-    if isinstance(self._params.task.model, configs.maskrcnn.MaskRCNN):
-      self._model = factory.build_maskrcnn(
-          input_specs=input_specs,
-          model_config=self._params.task.model)
-    elif isinstance(self._params.task.model, configs.retinanet.RetinaNet):
-      self._model = factory.build_retinanet(
-          input_specs=input_specs,
-          model_config=self._params.task.model)
+    if isinstance(self.params.task.model, configs.maskrcnn.MaskRCNN):
+      model = factory.build_maskrcnn(
+          input_specs=input_specs, model_config=self.params.task.model)
+    elif isinstance(self.params.task.model, configs.retinanet.RetinaNet):
+      model = factory.build_retinanet(
+          input_specs=input_specs, model_config=self.params.task.model)
     else:
       raise ValueError('Detection module not implemented for {} model.'.format(
-          type(self._params.task.model)))
+          type(self.params.task.model)))
 
-    return self._model
+    return model
 
   def _build_inputs(self, image):
     """Builds detection model inputs for serving."""
-    model_params = self._params.task.model
+    model_params = self.params.task.model
     # Normalizes image with mean and std pixel values.
     image = preprocess_ops.normalize_image(image,
                                            offset=MEAN_RGB,
@@ -81,7 +79,7 @@ class DetectionModule(export_base.ExportModule):
 
     return image, anchor_boxes, image_info
 
-  def _run_inference_on_image_tensors(self, images: tf.Tensor):
+  def serve(self, images: tf.Tensor):
     """Cast image to float and run inference.
 
     Args:
@@ -89,7 +87,7 @@ class DetectionModule(export_base.ExportModule):
 
     Returns:
       Tensor holding detection output logits.
     """
-    model_params = self._params.task.model
+    model_params = self.params.task.model
 
     with tf.device('cpu:0'):
       images = tf.cast(images, dtype=tf.float32)
@@ -122,7 +120,7 @@ class DetectionModule(export_base.ExportModule):
     input_image_shape = image_info[:, 1, :]
 
-    detections = self._model.call(
+    detections = self.model.call(
         images=images,
         image_shape=input_image_shape,
         anchor_boxes=anchor_boxes,
......
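An aside, not part of the commit: after this refactor, serve() is the single device-agnostic inference entry point, and the @tf.function wrappers in the shared base class (shown further below) all dispatch to it. A hedged sketch of calling it directly, with constructor arguments borrowed from the detection test that follows; `params` stands for an ExperimentConfig like the one that test builds.

module = DetectionModule(params, batch_size=1, input_image_size=[640, 640])
images = tf.zeros([1, 640, 640, 3], dtype=tf.uint8)  # serve() casts to float32
outputs = module.serve(images)  # e.g. outputs['num_detections']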
@@ -38,35 +38,10 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
         params, batch_size=1, input_image_size=[640, 640])
     return detection_module
 
-  def _export_from_module(self, module, input_type, batch_size, save_directory):
-    if input_type == 'image_tensor':
-      input_signature = tf.TensorSpec(
-          shape=[batch_size, None, None, 3], dtype=tf.uint8)
-      signatures = {
-          'serving_default':
-              module.inference_from_image_tensors.get_concrete_function(
-                  input_signature)
-      }
-    elif input_type == 'image_bytes':
-      input_signature = tf.TensorSpec(shape=[batch_size], dtype=tf.string)
-      signatures = {
-          'serving_default':
-              module.inference_from_image_bytes.get_concrete_function(
-                  input_signature)
-      }
-    elif input_type == 'tf_example':
-      input_signature = tf.TensorSpec(shape=[batch_size], dtype=tf.string)
-      signatures = {
-          'serving_default':
-              module.inference_from_tf_example.get_concrete_function(
-                  input_signature)
-      }
-    else:
-      raise ValueError('Unrecognized `input_type`')
-    tf.saved_model.save(module,
-                        save_directory,
-                        signatures=signatures)
+  def _export_from_module(self, module, input_type, save_directory):
+    signatures = module.get_inference_signatures(
+        {input_type: 'serving_default'})
+    tf.saved_model.save(module, save_directory, signatures=signatures)
 
   def _get_dummy_input(self, input_type, batch_size, image_size):
     """Get dummy input for the given input type."""
@@ -107,23 +82,23 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
   )
   def test_export(self, input_type, experiment_name, image_size):
     tmp_dir = self.get_temp_dir()
-    batch_size = 1
     module = self._get_detection_module(experiment_name)
-    model = module.build_model()
 
-    self._export_from_module(module, input_type, batch_size, tmp_dir)
+    self._export_from_module(module, input_type, tmp_dir)
 
     self.assertTrue(os.path.exists(os.path.join(tmp_dir, 'saved_model.pb')))
-    self.assertTrue(os.path.exists(
-        os.path.join(tmp_dir, 'variables', 'variables.index')))
-    self.assertTrue(os.path.exists(
-        os.path.join(tmp_dir, 'variables', 'variables.data-00000-of-00001')))
+    self.assertTrue(
+        os.path.exists(os.path.join(tmp_dir, 'variables', 'variables.index')))
+    self.assertTrue(
+        os.path.exists(
+            os.path.join(tmp_dir, 'variables',
+                         'variables.data-00000-of-00001')))
 
     imported = tf.saved_model.load(tmp_dir)
     detection_fn = imported.signatures['serving_default']
 
-    images = self._get_dummy_input(input_type, batch_size, image_size)
+    images = self._get_dummy_input(
+        input_type, batch_size=1, image_size=image_size)
 
     processed_images, anchor_boxes, image_info = module._build_inputs(
         tf.zeros((224, 224, 3), dtype=tf.uint8))
@@ -133,7 +108,7 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
     for l, l_boxes in anchor_boxes.items():
       anchor_boxes[l] = tf.expand_dims(l_boxes, 0)
 
-    expected_outputs = model(
+    expected_outputs = module.model(
         images=processed_images,
         image_shape=image_shape,
         anchor_boxes=anchor_boxes,
@@ -143,5 +118,6 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
     self.assertAllClose(outputs['num_detections'].numpy(),
                         expected_outputs['num_detections'].numpy())
 
 
 if __name__ == '__main__':
   tf.test.main()
@@ -16,20 +16,22 @@
 """Base class for model export."""
 
 import abc
-from typing import Optional, Sequence, Mapping
+from typing import Dict, List, Mapping, Optional, Text
 
 import tensorflow as tf
 
+from official.core import export_base
 from official.modeling.hyperparams import config_definitions as cfg
 
 
-class ExportModule(tf.Module, metaclass=abc.ABCMeta):
+class ExportModule(export_base.ExportModule, metaclass=abc.ABCMeta):
   """Base Export Module."""
 
   def __init__(self,
                params: cfg.ExperimentConfig,
                *,
                batch_size: int,
-               input_image_size: Sequence[int],
+               input_image_size: List[int],
                num_channels: int = 3,
                model: Optional[tf.keras.Model] = None):
     """Initializes a module for export.
@@ -42,13 +44,13 @@ class ExportModule(tf.Module, metaclass=abc.ABCMeta):
       num_channels: The number of the image channels.
       model: A tf.keras.Model instance to be exported.
     """
-    super(ExportModule, self).__init__()
-    self._params = params
+    self.params = params
     self._batch_size = batch_size
     self._input_image_size = input_image_size
     self._num_channels = num_channels
-    self._model = model
+    if model is None:
+      model = self._build_model()  # pylint: disable=assignment-from-none
+    super().__init__(params=params, model=model)
 
   def _decode_image(self, encoded_image_bytes: str) -> tf.Tensor:
     """Decodes an image bytes to an image tensor.
@@ -92,45 +94,40 @@ class ExportModule(tf.Module, metaclass=abc.ABCMeta):
     image_tensor = self._decode_image(parsed_tensors['image/encoded'])
     return image_tensor
 
-  @abc.abstractmethod
-  def build_model(self, **kwargs):
-    """Builds model and sets self._model."""
-
-  @abc.abstractmethod
-  def _run_inference_on_image_tensors(
-      self, images: tf.Tensor) -> Mapping[str, tf.Tensor]:
-    """Runs inference on images."""
+  def _build_model(self, **kwargs):
+    """Returns a model built from the params."""
+    return None
 
   @tf.function
   def inference_from_image_tensors(
-      self, input_tensor: tf.Tensor) -> Mapping[str, tf.Tensor]:
-    return self._run_inference_on_image_tensors(input_tensor)
+      self, inputs: tf.Tensor) -> Mapping[str, tf.Tensor]:
+    return self.serve(inputs)
 
   @tf.function
-  def inference_from_image_bytes(self, input_tensor: str):
+  def inference_from_image_bytes(self, inputs: tf.Tensor):
     with tf.device('cpu:0'):
       images = tf.nest.map_structure(
           tf.identity,
           tf.map_fn(
              self._decode_image,
-              elems=input_tensor,
+              elems=inputs,
              fn_output_signature=tf.TensorSpec(
                  shape=[None] * len(self._input_image_size) +
                  [self._num_channels],
                  dtype=tf.uint8),
              parallel_iterations=32))
      images = tf.stack(images)
-    return self._run_inference_on_image_tensors(images)
+    return self.serve(images)
 
   @tf.function
-  def inference_from_tf_example(
-      self, input_tensor: tf.train.Example) -> Mapping[str, tf.Tensor]:
+  def inference_from_tf_example(self,
+                                inputs: tf.Tensor) -> Mapping[str, tf.Tensor]:
     with tf.device('cpu:0'):
       images = tf.nest.map_structure(
          tf.identity,
          tf.map_fn(
              self._decode_tf_example,
-              elems=input_tensor,
+              elems=inputs,
              # Height/width of the shape of input images is unspecified (None)
              # at the time of decoding the example, but the shape will
              # be adjusted to conform to the input layer of the model,
@@ -142,4 +139,41 @@ class ExportModule(tf.Module, metaclass=abc.ABCMeta):
              dtype=tf.uint8,
              parallel_iterations=32))
      images = tf.stack(images)
-    return self._run_inference_on_image_tensors(images)
+    return self.serve(images)
+
+  def get_inference_signatures(self, function_keys: Dict[Text, Text]):
+    """Gets defined function signatures.
+
+    Args:
+      function_keys: A dictionary mapping each input type to the signature key
+        its function should be registered under in the returned dictionary.
+
+    Returns:
+      A dictionary mapping each signature key to a concrete function that can
+      be passed to tf.saved_model.save.
+    """
+    signatures = {}
+    for key, def_name in function_keys.items():
+      if key == 'image_tensor':
+        input_signature = tf.TensorSpec(
+            shape=[self._batch_size] + [None] * len(self._input_image_size) +
+            [self._num_channels],
+            dtype=tf.uint8)
+        signatures[
+            def_name] = self.inference_from_image_tensors.get_concrete_function(
+                input_signature)
+      elif key == 'image_bytes':
+        input_signature = tf.TensorSpec(
+            shape=[self._batch_size], dtype=tf.string)
+        signatures[
+            def_name] = self.inference_from_image_bytes.get_concrete_function(
+                input_signature)
+      elif key == 'serve_examples' or key == 'tf_example':
+        input_signature = tf.TensorSpec(
+            shape=[self._batch_size], dtype=tf.string)
+        signatures[
+            def_name] = self.inference_from_tf_example.get_concrete_function(
+                input_signature)
+      else:
+        raise ValueError('Unrecognized `input_type`')
+    return signatures
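To make the function_keys contract concrete, a sketch (not part of the diff): each key names a supported input type, and each value is the signature key the concrete function is registered under. DetectionModule construction mirrors the detection test earlier in this diff; the save path is illustrative.

module = DetectionModule(params, batch_size=1, input_image_size=[640, 640])
signatures = module.get_inference_signatures(
    {'image_tensor': 'serving_default', 'tf_example': 'serve_examples'})
tf.saved_model.save(module, '/tmp/detection_sketch', signatures=signatures)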
@@ -16,16 +16,15 @@
 r"""Vision models export utility function for serving/inference."""
 
 import os
 from typing import Optional, List
 
 import tensorflow as tf
 
 from official.core import config_definitions as cfg
+from official.core import export_base
 from official.core import train_utils
 from official.vision.beta import configs
 from official.vision.beta.serving import detection
-from official.vision.beta.serving import export_base
 from official.vision.beta.serving import image_classification
 from official.vision.beta.serving import semantic_segmentation
@@ -75,6 +74,7 @@ def export_inference_graph(
   else:
     output_saved_model_directory = export_dir
 
+  # TODO(arashwan): Offer a direct path to use ExportModule with Task objects.
   if not export_module:
     if isinstance(params.task,
                   configs.image_classification.ImageClassificationTask):
@@ -101,47 +101,13 @@
       raise ValueError('Export module not implemented for {} task.'.format(
           type(params.task)))
 
-  model = export_module.build_model()
-  ckpt = tf.train.Checkpoint(model=model)
-
-  ckpt_dir_or_file = checkpoint_path
-  if tf.io.gfile.isdir(ckpt_dir_or_file):
-    ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
-  status = ckpt.restore(ckpt_dir_or_file).expect_partial()
-
-  if input_type == 'image_tensor':
-    input_signature = tf.TensorSpec(
-        shape=[batch_size] + [None] * len(input_image_size) + [num_channels],
-        dtype=tf.uint8)
-    signatures = {
-        'serving_default':
-            export_module.inference_from_image_tensors.get_concrete_function(
-                input_signature)
-    }
-  elif input_type == 'image_bytes':
-    input_signature = tf.TensorSpec(shape=[batch_size], dtype=tf.string)
-    signatures = {
-        'serving_default':
-            export_module.inference_from_image_bytes.get_concrete_function(
-                input_signature)
-    }
-  elif input_type == 'tf_example':
-    input_signature = tf.TensorSpec(shape=[batch_size], dtype=tf.string)
-    signatures = {
-        'serving_default':
-            export_module.inference_from_tf_example.get_concrete_function(
-                input_signature)
-    }
-  else:
-    raise ValueError('Unrecognized `input_type`')
-
-  status.assert_existing_objects_matched()
+  export_base.export(
+      export_module,
+      function_keys=[input_type],
+      export_savedmodel_dir=output_saved_model_directory,
+      checkpoint_path=checkpoint_path,
+      timestamped=False)
 
+  ckpt = tf.train.Checkpoint(model=export_module.model)
+  ckpt.save(os.path.join(output_checkpoint_directory, 'ckpt'))
-  tf.saved_model.save(export_module,
-                      output_saved_model_directory,
-                      signatures=signatures)
 
   train_utils.serialize_config(params, export_dir)
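A hedged sketch of driving the slimmed-down utility end to end. The keyword names below are inferred from this hunk and the surrounding context; the full signature of export_inference_graph is not shown in this diff, so treat the call shape as an assumption.

export_saved_model_lib.export_inference_graph(
    input_type='image_tensor',        # branched on above
    batch_size=1,
    input_image_size=[640, 640],
    params=params,                    # a cfg.ExperimentConfig
    checkpoint_path='/path/to/ckpt',  # illustrative
    export_dir='/tmp/export_sketch')  # illustrative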
@@ -24,7 +24,7 @@ import tensorflow as tf
 from official.common import registry_imports  # pylint: disable=unused-import
 from official.core import exp_factory
 from official.modeling import hyperparams
-from official.vision.beta.serving import image_classification
+from official.vision.beta.modeling import factory
 
 FLAGS = flags.FLAGS
@@ -68,10 +68,14 @@ def export_model_to_tfhub(params,
                           checkpoint_path,
                           export_path):
   """Export an image classification model to TF-Hub."""
-  export_module = image_classification.ClassificationModule(
-      params=params, batch_size=batch_size, input_image_size=input_image_size)
+  input_specs = tf.keras.layers.InputSpec(shape=[batch_size] +
+                                          input_image_size + [3])
-  model = export_module.build_model(skip_logits_layer=skip_logits_layer)
+  model = factory.build_classification_model(
+      input_specs=input_specs,
+      model_config=params.task.model,
+      l2_regularizer=None,
+      skip_logits_layer=skip_logits_layer)
   checkpoint = tf.train.Checkpoint(model=model)
   checkpoint.restore(checkpoint_path).assert_existing_objects_matched()
   model.save(export_path, include_optimizer=False, save_format='tf')
......
@@ -29,17 +29,14 @@ STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)
 
 
 class ClassificationModule(export_base.ExportModule):
   """Classification Module."""
 
-  def build_model(self, skip_logits_layer=False):
+  def _build_model(self):
     input_specs = tf.keras.layers.InputSpec(
         shape=[self._batch_size] + self._input_image_size + [3])
 
-    self._model = factory.build_classification_model(
+    return factory.build_classification_model(
         input_specs=input_specs,
-        model_config=self._params.task.model,
-        l2_regularizer=None,
-        skip_logits_layer=skip_logits_layer)
-    return self._model
+        model_config=self.params.task.model,
+        l2_regularizer=None)
 
   def _build_inputs(self, image):
     """Builds classification model inputs for serving."""
@@ -58,7 +55,7 @@ class ClassificationModule(export_base.ExportModule):
         scale=STDDEV_RGB)
     return image
 
-  def _run_inference_on_image_tensors(self, images):
+  def serve(self, images):
     """Cast image to float and run inference.
 
     Args:
@@ -79,6 +76,6 @@ class ClassificationModule(export_base.ExportModule):
           )
       )
 
-    logits = self._model(images, training=False)
+    logits = self.inference_step(images)
 
     return dict(outputs=logits)
@@ -38,30 +38,8 @@ class ImageClassificationExportTest(tf.test.TestCase, parameterized.TestCase):
     return classification_module
 
   def _export_from_module(self, module, input_type, save_directory):
-    if input_type == 'image_tensor':
-      input_signature = tf.TensorSpec(shape=[None, 224, 224, 3], dtype=tf.uint8)
-      signatures = {
-          'serving_default':
-              module.inference_from_image_tensors.get_concrete_function(
-                  input_signature)
-      }
-    elif input_type == 'image_bytes':
-      input_signature = tf.TensorSpec(shape=[None], dtype=tf.string)
-      signatures = {
-          'serving_default':
-              module.inference_from_image_bytes.get_concrete_function(
-                  input_signature)
-      }
-    elif input_type == 'tf_example':
-      input_signature = tf.TensorSpec(shape=[None], dtype=tf.string)
-      signatures = {
-          'serving_default':
-              module.inference_from_tf_example.get_concrete_function(
-                  input_signature)
-      }
-    else:
-      raise ValueError('Unrecognized `input_type`')
+    signatures = module.get_inference_signatures(
+        {input_type: 'serving_default'})
     tf.saved_model.save(module,
                         save_directory,
                         signatures=signatures)
@@ -95,9 +73,7 @@ class ImageClassificationExportTest(tf.test.TestCase, parameterized.TestCase):
   )
   def test_export(self, input_type='image_tensor'):
     tmp_dir = self.get_temp_dir()
     module = self._get_classification_module()
-    model = module.build_model()
 
     self._export_from_module(module, input_type, tmp_dir)
@@ -118,7 +94,7 @@ class ImageClassificationExportTest(tf.test.TestCase, parameterized.TestCase):
             elems=tf.zeros((1, 224, 224, 3), dtype=tf.uint8),
             fn_output_signature=tf.TensorSpec(
                 shape=[224, 224, 3], dtype=tf.float32)))
-    expected_output = model(processed_images, training=False)
+    expected_output = module.model(processed_images, training=False)
 
     out = classification_fn(tf.constant(images))
     self.assertAllClose(out['outputs'].numpy(), expected_output.numpy())
......
@@ -29,17 +29,15 @@ STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)
 
 
 class SegmentationModule(export_base.ExportModule):
   """Segmentation Module."""
 
-  def build_model(self):
+  def _build_model(self):
     input_specs = tf.keras.layers.InputSpec(
         shape=[self._batch_size] + self._input_image_size + [3])
 
-    self._model = factory.build_segmentation_model(
+    return factory.build_segmentation_model(
         input_specs=input_specs,
-        model_config=self._params.task.model,
+        model_config=self.params.task.model,
         l2_regularizer=None)
-    return self._model
 
   def _build_inputs(self, image):
     """Builds segmentation model inputs for serving."""
@@ -56,7 +54,7 @@ class SegmentationModule(export_base.ExportModule):
         aug_scale_max=1.0)
     return image
 
-  def _run_inference_on_image_tensors(self, images):
+  def serve(self, images):
     """Cast image to float and run inference.
 
     Args:
@@ -77,7 +75,7 @@ class SegmentationModule(export_base.ExportModule):
           )
       )
 
-    masks = self._model(images, training=False)
+    masks = self.inference_step(images)
     masks = tf.image.resize(masks, self._input_image_size, method='bilinear')
 
     return dict(predicted_masks=masks)
@@ -38,33 +38,9 @@ class SemanticSegmentationExportTest(tf.test.TestCase, parameterized.TestCase):
     return segmentation_module
 
   def _export_from_module(self, module, input_type, save_directory):
-    if input_type == 'image_tensor':
-      input_signature = tf.TensorSpec(shape=[None, 112, 112, 3], dtype=tf.uint8)
-      signatures = {
-          'serving_default':
-              module.inference_from_image_tensors.get_concrete_function(
-                  input_signature)
-      }
-    elif input_type == 'image_bytes':
-      input_signature = tf.TensorSpec(shape=[None], dtype=tf.string)
-      signatures = {
-          'serving_default':
-              module.inference_from_image_bytes.get_concrete_function(
-                  input_signature)
-      }
-    elif input_type == 'tf_example':
-      input_signature = tf.TensorSpec(shape=[None], dtype=tf.string)
-      signatures = {
-          'serving_default':
-              module.inference_from_tf_example.get_concrete_function(
-                  input_signature)
-      }
-    else:
-      raise ValueError('Unrecognized `input_type`')
-    tf.saved_model.save(module,
-                        save_directory,
-                        signatures=signatures)
+    signatures = module.get_inference_signatures(
+        {input_type: 'serving_default'})
+    tf.saved_model.save(module, save_directory, signatures=signatures)
 
   def _get_dummy_input(self, input_type):
     """Get dummy input for the given input type."""
@@ -95,17 +71,17 @@ class SemanticSegmentationExportTest(tf.test.TestCase, parameterized.TestCase):
   )
   def test_export(self, input_type='image_tensor'):
     tmp_dir = self.get_temp_dir()
     module = self._get_segmentation_module()
-    model = module.build_model()
 
     self._export_from_module(module, input_type, tmp_dir)
 
     self.assertTrue(os.path.exists(os.path.join(tmp_dir, 'saved_model.pb')))
-    self.assertTrue(os.path.exists(
-        os.path.join(tmp_dir, 'variables', 'variables.index')))
-    self.assertTrue(os.path.exists(
-        os.path.join(tmp_dir, 'variables', 'variables.data-00000-of-00001')))
+    self.assertTrue(
+        os.path.exists(os.path.join(tmp_dir, 'variables', 'variables.index')))
+    self.assertTrue(
+        os.path.exists(
+            os.path.join(tmp_dir, 'variables',
+                         'variables.data-00000-of-00001')))
 
     imported = tf.saved_model.load(tmp_dir)
     segmentation_fn = imported.signatures['serving_default']
@@ -119,9 +95,11 @@ class SemanticSegmentationExportTest(tf.test.TestCase, parameterized.TestCase):
             fn_output_signature=tf.TensorSpec(
                 shape=[112, 112, 3], dtype=tf.float32)))
     expected_output = tf.image.resize(
-        model(processed_images, training=False), [112, 112], method='bilinear')
+        module.model(processed_images, training=False), [112, 112],
+        method='bilinear')
 
     out = segmentation_fn(tf.constant(images))
     self.assertAllClose(out['predicted_masks'].numpy(), expected_output.numpy())
 
 
 if __name__ == '__main__':
   tf.test.main()
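With that, all three serving test files reduce to the same export pattern. A condensed sketch; the SegmentationModule constructor arguments are an assumption, mirrored from the detection test, and the save path is illustrative.

module = SegmentationModule(params, batch_size=1, input_image_size=[112, 112])
signatures = module.get_inference_signatures({'image_tensor': 'serving_default'})
tf.saved_model.save(module, '/tmp/segmentation_sketch', signatures=signatures)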