Commit b287dfd9 authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Add a normalize_image_impl function without tf.image.convert_image_dtype as it...

Add a normalize_image_impl function without tf.image.convert_image_dtype as it is not supported in TFLite. By using normalize_image_impl, normalization can be included in the TFLite graph. The original normalize_image calls normalize_image_impl so there is no affect on existing code.

PiperOrigin-RevId: 467278332
parent 14ed98a7
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
"""Preprocessing ops.""" """Preprocessing ops."""
import math import math
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Sequence, Union
from six.moves import range from six.moves import range
import tensorflow as tf import tensorflow as tf
...@@ -65,22 +65,42 @@ def clip_or_pad_to_fixed_size(input_tensor, size, constant_values=0): ...@@ -65,22 +65,42 @@ def clip_or_pad_to_fixed_size(input_tensor, size, constant_values=0):
return padded_tensor return padded_tensor
def normalize_image(image, def normalize_image(image: tf.Tensor,
offset=(0.485, 0.456, 0.406), offset: Sequence[float] = (0.485, 0.456, 0.406),
scale=(0.229, 0.224, 0.225)): scale: Sequence[float] = (0.229, 0.224, 0.225)):
"""Normalizes the image to zero mean and unit variance.""" """Normalizes the image to zero mean and unit variance."""
with tf.name_scope('normalize_image'): with tf.name_scope('normalize_image'):
image = tf.image.convert_image_dtype(image, dtype=tf.float32) image = tf.image.convert_image_dtype(image, dtype=tf.float32)
offset = tf.constant(offset) return normalize_scaled_float_image(image, offset, scale)
offset = tf.expand_dims(offset, axis=0)
offset = tf.expand_dims(offset, axis=0)
image -= offset def normalize_scaled_float_image(image: tf.Tensor,
offset: Sequence[float] = (0.485, 0.456,
scale = tf.constant(scale) 0.406),
scale = tf.expand_dims(scale, axis=0) scale: Sequence[float] = (0.229, 0.224,
scale = tf.expand_dims(scale, axis=0) 0.225)):
image /= scale """Normalizes a scaled float image to zero mean and unit variance.
return image
It assumes the input image is float dtype with values in [0, 1).
Args:
image: A tf.Tensor in float32 dtype with values in range [0, 1).
offset: A tuple of mean values to be subtracted from the image.
scale: A tuple of normalization factors.
Returns:
A normalized image tensor.
"""
offset = tf.constant(offset)
offset = tf.expand_dims(offset, axis=0)
offset = tf.expand_dims(offset, axis=0)
image -= offset
scale = tf.constant(scale)
scale = tf.expand_dims(scale, axis=0)
scale = tf.expand_dims(scale, axis=0)
image /= scale
return image
def compute_padded_size(desired_size, stride): def compute_padded_size(desired_size, stride):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment