Commit 675f4de3 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 355199395
parent c90f8b16
...@@ -55,7 +55,7 @@ class DetectionModule(export_base.ExportModule): ...@@ -55,7 +55,7 @@ class DetectionModule(export_base.ExportModule):
return self._model return self._model
def _build_inputs(self, image): def _build_inputs(self, image):
"""Builds classification model inputs for serving.""" """Builds detection model inputs for serving."""
model_params = self._params.task.model model_params = self._params.task.model
# Normalizes image with mean and std pixel values. # Normalizes image with mean and std pixel values.
image = preprocess_ops.normalize_image(image, image = preprocess_ops.normalize_image(image,
...@@ -89,7 +89,7 @@ class DetectionModule(export_base.ExportModule): ...@@ -89,7 +89,7 @@ class DetectionModule(export_base.ExportModule):
Args: Args:
images: uint8 Tensor of shape [batch_size, None, None, 3] images: uint8 Tensor of shape [batch_size, None, None, 3]
Returns: Returns:
Tensor holding classification output logits. Tensor holding detection output logits.
""" """
model_params = self._params.task.model model_params = self._params.task.model
with tf.device('cpu:0'): with tf.device('cpu:0'):
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Test for image classification export lib.""" """Test for image detection export lib."""
import io import io
import os import os
...@@ -41,7 +41,7 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase): ...@@ -41,7 +41,7 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
def _export_from_module(self, module, input_type, batch_size, save_directory): def _export_from_module(self, module, input_type, batch_size, save_directory):
if input_type == 'image_tensor': if input_type == 'image_tensor':
input_signature = tf.TensorSpec( input_signature = tf.TensorSpec(
shape=[batch_size, 640, 640, 3], dtype=tf.uint8) shape=[batch_size, None, None, 3], dtype=tf.uint8)
signatures = { signatures = {
'serving_default': 'serving_default':
module.inference_from_image_tensors.get_concrete_function( module.inference_from_image_tensors.get_concrete_function(
...@@ -68,18 +68,19 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase): ...@@ -68,18 +68,19 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
save_directory, save_directory,
signatures=signatures) signatures=signatures)
def _get_dummy_input(self, input_type, batch_size): def _get_dummy_input(self, input_type, batch_size, image_size):
"""Get dummy input for the given input type.""" """Get dummy input for the given input type."""
h, w = image_size
if input_type == 'image_tensor': if input_type == 'image_tensor':
return tf.zeros((batch_size, 640, 640, 3), dtype=np.uint8) return tf.zeros((batch_size, h, w, 3), dtype=np.uint8)
elif input_type == 'image_bytes': elif input_type == 'image_bytes':
image = Image.fromarray(np.zeros((640, 640, 3), dtype=np.uint8)) image = Image.fromarray(np.zeros((h, w, 3), dtype=np.uint8))
byte_io = io.BytesIO() byte_io = io.BytesIO()
image.save(byte_io, 'PNG') image.save(byte_io, 'PNG')
return [byte_io.getvalue() for b in range(batch_size)] return [byte_io.getvalue() for b in range(batch_size)]
elif input_type == 'tf_example': elif input_type == 'tf_example':
image_tensor = tf.zeros((640, 640, 3), dtype=tf.uint8) image_tensor = tf.zeros((h, w, 3), dtype=tf.uint8)
encoded_jpeg = tf.image.encode_jpeg(tf.constant(image_tensor)).numpy() encoded_jpeg = tf.image.encode_jpeg(tf.constant(image_tensor)).numpy()
example = tf.train.Example( example = tf.train.Example(
features=tf.train.Features( features=tf.train.Features(
...@@ -91,21 +92,23 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase): ...@@ -91,21 +92,23 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
return [example for b in range(batch_size)] return [example for b in range(batch_size)]
@parameterized.parameters( @parameterized.parameters(
('image_tensor', 'fasterrcnn_resnetfpn_coco'), ('image_tensor', 'fasterrcnn_resnetfpn_coco', [384, 384]),
('image_bytes', 'fasterrcnn_resnetfpn_coco'), ('image_bytes', 'fasterrcnn_resnetfpn_coco', [640, 640]),
('tf_example', 'fasterrcnn_resnetfpn_coco'), ('tf_example', 'fasterrcnn_resnetfpn_coco', [640, 640]),
('image_tensor', 'maskrcnn_resnetfpn_coco'), ('image_tensor', 'maskrcnn_resnetfpn_coco', [640, 640]),
('image_bytes', 'maskrcnn_resnetfpn_coco'), ('image_bytes', 'maskrcnn_resnetfpn_coco', [640, 384]),
('tf_example', 'maskrcnn_resnetfpn_coco'), ('tf_example', 'maskrcnn_resnetfpn_coco', [640, 640]),
('image_tensor', 'retinanet_resnetfpn_coco'), ('image_tensor', 'retinanet_resnetfpn_coco', [640, 640]),
('image_bytes', 'retinanet_resnetfpn_coco'), ('image_bytes', 'retinanet_resnetfpn_coco', [640, 640]),
('tf_example', 'retinanet_resnetfpn_coco'), ('tf_example', 'retinanet_resnetfpn_coco', [384, 640]),
('image_tensor', 'retinanet_resnetfpn_coco', [384, 384]),
('image_bytes', 'retinanet_spinenet_coco', [640, 640]),
('tf_example', 'retinanet_spinenet_coco', [640, 384]),
) )
def test_export(self, input_type, experiment_name): def test_export(self, input_type, experiment_name, image_size):
tmp_dir = self.get_temp_dir() tmp_dir = self.get_temp_dir()
batch_size = 1 batch_size = 1
experiment_name = 'fasterrcnn_resnetfpn_coco'
module = self._get_detection_module(experiment_name) module = self._get_detection_module(experiment_name)
model = module.build_model() model = module.build_model()
...@@ -118,9 +121,9 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase): ...@@ -118,9 +121,9 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
os.path.join(tmp_dir, 'variables', 'variables.data-00000-of-00001'))) os.path.join(tmp_dir, 'variables', 'variables.data-00000-of-00001')))
imported = tf.saved_model.load(tmp_dir) imported = tf.saved_model.load(tmp_dir)
classification_fn = imported.signatures['serving_default'] detection_fn = imported.signatures['serving_default']
images = self._get_dummy_input(input_type, batch_size) images = self._get_dummy_input(input_type, batch_size, image_size)
processed_images, anchor_boxes, image_shape = module._build_inputs( processed_images, anchor_boxes, image_shape = module._build_inputs(
tf.zeros((224, 224, 3), dtype=tf.uint8)) tf.zeros((224, 224, 3), dtype=tf.uint8))
...@@ -134,7 +137,7 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase): ...@@ -134,7 +137,7 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
image_shape=image_shape, image_shape=image_shape,
anchor_boxes=anchor_boxes, anchor_boxes=anchor_boxes,
training=False) training=False)
outputs = classification_fn(tf.constant(images)) outputs = detection_fn(tf.constant(images))
self.assertAllClose(outputs['num_detections'].numpy(), self.assertAllClose(outputs['num_detections'].numpy(),
expected_outputs['num_detections'].numpy()) expected_outputs['num_detections'].numpy())
......
...@@ -73,7 +73,7 @@ class ExportModule(tf.Module, metaclass=abc.ABCMeta): ...@@ -73,7 +73,7 @@ class ExportModule(tf.Module, metaclass=abc.ABCMeta):
_decode_image, _decode_image,
elems=input_tensor, elems=input_tensor,
fn_output_signature=tf.TensorSpec( fn_output_signature=tf.TensorSpec(
shape=self._input_image_size + [3], dtype=tf.uint8), shape=[None, None, 3], dtype=tf.uint8),
parallel_iterations=32)) parallel_iterations=32))
images = tf.stack(images) images = tf.stack(images)
return self._run_inference_on_image_tensors(images) return self._run_inference_on_image_tensors(images)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment