Commit 5e8370ca authored by Jiageng Zhang's avatar Jiageng Zhang Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 475928219
parent d5cd9b0a
......@@ -17,7 +17,7 @@
import dataclasses
import imghdr
import io
from typing import Tuple
from typing import Optional, Tuple
import numpy as np
from PIL import Image
......@@ -35,6 +35,7 @@ class ImageFormat:
bmp: str = 'BMP'
png: str = 'PNG'
jpeg: str = 'JPEG'
raw: str = 'RAW'
def validate_image_format(format_str: str) -> str:
......@@ -66,8 +67,11 @@ def encode_image(image_np: np.ndarray, image_format: str) -> bytes:
image_format: An enum specifying the format of the generated image.
Returns:
Encoded image string
Encoded image string.
"""
if image_format == 'RAW':
return image_np.tobytes()
if len(image_np.shape) > 2 and image_np.shape[2] == 1:
image_pil = Image.fromarray(np.squeeze(image_np), 'L')
else:
......@@ -77,7 +81,12 @@ def encode_image(image_np: np.ndarray, image_format: str) -> bytes:
return output.getvalue()
def decode_image(image_bytes: bytes) -> np.ndarray:
def decode_image(image_bytes: bytes,
image_format: Optional[str] = None,
image_dtype: str = 'uint8') -> np.ndarray:
"""Decodes image_bytes into numpy array."""
if image_format == 'RAW':
return np.frombuffer(image_bytes, dtype=image_dtype)
image_pil = Image.open(io.BytesIO(image_bytes))
image_np = np.array(image_pil)
if len(image_np.shape) < 3:
......@@ -88,6 +97,9 @@ def decode_image(image_bytes: bytes) -> np.ndarray:
def decode_image_metadata(image_bytes: bytes) -> Tuple[int, int, int, str]:
"""Decodes image metadata from encoded image string.
Note that if the image is encoded in RAW format, the metadata cannot be
inferred from the image bytes.
Args:
image_bytes: Encoded image string.
......
......@@ -39,6 +39,21 @@ class ImageUtilsTest(parameterized.TestCase, tf.test.TestCase):
self.assertAllClose(actual_image_np, image_np)
self.assertEqual(actual_image_np.shape, image_np.shape)
@parameterized.named_parameters(
    ('RGB_RAW', 128, 64, 3, tf.bfloat16.as_numpy_dtype),
    ('GREY_RAW', 32, 32, 1, tf.uint8.as_numpy_dtype))
def test_encode_raw_image_then_decode_raw_image(self, height, width,
                                                num_channels, image_dtype):
  """Round-trips a fake image through the 'RAW' encode/decode path.

  Args:
    height: Height of the generated fake image.
    width: Width of the generated fake image.
    num_channels: Number of channels of the generated fake image.
    image_dtype: Numpy dtype the image is cast to before encoding and that is
      handed back to `decode_image` when decoding.
  """
  expected_np = fake_feature_generator.generate_image_np(
      height, width, num_channels).astype(image_dtype)

  raw_bytes = image_utils.encode_image(expected_np, 'RAW')
  decoded_np = image_utils.decode_image(raw_bytes, 'RAW', image_dtype)
  # The test restores the original layout itself before comparing, because the
  # decode call is given only the bytes, the format and the dtype.
  decoded_np = decoded_np.reshape((height, width, num_channels))

  self.assertAllClose(decoded_np, expected_np)
  self.assertEqual(decoded_np.shape, expected_np.shape)
@parameterized.named_parameters(
('RGB_PNG', 128, 64, 3, 'PNG'), ('RGB_JPEG', 64, 128, 3, 'JPEG'),
('GREY_BMP', 32, 32, 1, 'BMP'), ('GREY_PNG', 128, 128, 1, 'png'))
......
......@@ -119,6 +119,10 @@ class TfExampleBuilder(tf_example_builder.TfExampleBuilder):
Returns:
The builder object for subsequent method calls.
"""
if image_format == 'RAW':
if not (height and width and num_channels):
raise ValueError('For raw image feature, height, width and '
'num_channels fields are required.')
if not all((height, width, num_channels, image_format)):
(height, width, num_channels, image_format) = (
image_utils.decode_image_metadata(encoded_image))
......
......@@ -24,6 +24,7 @@ from official.vision.data import tf_example_builder
class TfExampleBuilderTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(('RGB_PNG', 128, 64, 3, 'PNG', '15125990'),
('RGB_RAW', 128, 128, 3, 'RAW', '5607919'),
('RGB_JPEG', 64, 128, 3, 'JPEG', '3107796'))
def test_add_image_matrix_feature_success(self, height, width, num_channels,
image_format, hashed_image):
......@@ -111,6 +112,62 @@ class TfExampleBuilderTest(tf.test.TestCase, parameterized.TestCase):
bytes_list=tf.train.BytesList(value=[hashed_image]))
})), example)
def test_add_encoded_raw_image_feature_success(self):
  """Checks the example built from a RAW image with explicit dimensions.

  A bfloat16 fake image is encoded as 'RAW' and added to the builder together
  with its height, width and channel count; the resulting `tf.train.Example`
  must contain the encoded bytes, format, dimensions and source id.
  """
  height, width, num_channels = 128, 128, 3
  image_format = 'RAW'
  raw_dtype = tf.bfloat16.as_numpy_dtype

  fake_image = fake_feature_generator.generate_image_np(
      height, width, num_channels).astype(raw_dtype)
  encoded_raw = image_utils.encode_image(fake_image, image_format)
  # Expected value of the 'image/source_id' feature for this fake image.
  expected_source_id = b'3572575'

  builder = tf_example_builder.TfExampleBuilder()
  builder.add_encoded_image_feature(encoded_raw, image_format, height, width,
                                    num_channels)
  actual_example = builder.example

  def _bytes_feature(value):
    # Wraps a single bytes value into a tf.train.Feature.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

  def _int64_feature(value):
    # Wraps a single integer value into a tf.train.Feature.
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

  expected_example = tf.train.Example(
      features=tf.train.Features(
          feature={
              'image/encoded': _bytes_feature(encoded_raw),
              'image/format': _bytes_feature(image_format.encode('ascii')),
              'image/height': _int64_feature(height),
              'image/width': _int64_feature(width),
              'image/channels': _int64_feature(num_channels),
              'image/source_id': _bytes_feature(expected_source_id),
          }))
  self.assertProtoEquals(expected_example, actual_example)
def test_add_encoded_raw_image_feature_valueerror(self):
  """Checks that a RAW image added without dimensions raises ValueError."""
  raw_format = 'RAW'
  raw_dtype = tf.bfloat16.as_numpy_dtype

  single_pixel = fake_feature_generator.generate_image_np(1, 1, 1)
  encoded_raw = image_utils.encode_image(
      single_pixel.astype(raw_dtype), raw_format)

  builder = tf_example_builder.TfExampleBuilder()
  # Height, width and num_channels are deliberately omitted: only the encoded
  # bytes and the format are passed, and the call is expected to fail.
  with self.assertRaises(ValueError):
    builder.add_encoded_image_feature(encoded_raw, raw_format)
@parameterized.parameters(
(True, True, True, True, True),
(False, False, False, False, False),
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment