Commit 5e8370ca authored by Jiageng Zhang's avatar Jiageng Zhang Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 475928219
parent d5cd9b0a
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
import dataclasses import dataclasses
import imghdr import imghdr
import io import io
from typing import Tuple from typing import Optional, Tuple
import numpy as np import numpy as np
from PIL import Image from PIL import Image
...@@ -35,6 +35,7 @@ class ImageFormat: ...@@ -35,6 +35,7 @@ class ImageFormat:
bmp: str = 'BMP' bmp: str = 'BMP'
png: str = 'PNG' png: str = 'PNG'
jpeg: str = 'JPEG' jpeg: str = 'JPEG'
raw: str = 'RAW'
def validate_image_format(format_str: str) -> str: def validate_image_format(format_str: str) -> str:
...@@ -66,8 +67,11 @@ def encode_image(image_np: np.ndarray, image_format: str) -> bytes: ...@@ -66,8 +67,11 @@ def encode_image(image_np: np.ndarray, image_format: str) -> bytes:
image_format: An enum specifying the format of the generated image. image_format: An enum specifying the format of the generated image.
Returns: Returns:
Encoded image string Encoded image string.
""" """
if image_format == 'RAW':
return image_np.tobytes()
if len(image_np.shape) > 2 and image_np.shape[2] == 1: if len(image_np.shape) > 2 and image_np.shape[2] == 1:
image_pil = Image.fromarray(np.squeeze(image_np), 'L') image_pil = Image.fromarray(np.squeeze(image_np), 'L')
else: else:
...@@ -77,7 +81,12 @@ def encode_image(image_np: np.ndarray, image_format: str) -> bytes: ...@@ -77,7 +81,12 @@ def encode_image(image_np: np.ndarray, image_format: str) -> bytes:
return output.getvalue() return output.getvalue()
def decode_image(image_bytes: bytes) -> np.ndarray: def decode_image(image_bytes: bytes,
image_format: Optional[str] = None,
image_dtype: str = 'uint8') -> np.ndarray:
"""Decodes image_bytes into numpy array."""
if image_format == 'RAW':
return np.frombuffer(image_bytes, dtype=image_dtype)
image_pil = Image.open(io.BytesIO(image_bytes)) image_pil = Image.open(io.BytesIO(image_bytes))
image_np = np.array(image_pil) image_np = np.array(image_pil)
if len(image_np.shape) < 3: if len(image_np.shape) < 3:
...@@ -88,6 +97,9 @@ def decode_image(image_bytes: bytes) -> np.ndarray: ...@@ -88,6 +97,9 @@ def decode_image(image_bytes: bytes) -> np.ndarray:
def decode_image_metadata(image_bytes: bytes) -> Tuple[int, int, int, str]: def decode_image_metadata(image_bytes: bytes) -> Tuple[int, int, int, str]:
"""Decodes image metadata from encoded image string. """Decodes image metadata from encoded image string.
Note that if the image is encoded in RAW format, the metadata cannot be
inferred from the image bytes.
Args: Args:
image_bytes: Encoded image string. image_bytes: Encoded image string.
......
...@@ -39,6 +39,21 @@ class ImageUtilsTest(parameterized.TestCase, tf.test.TestCase): ...@@ -39,6 +39,21 @@ class ImageUtilsTest(parameterized.TestCase, tf.test.TestCase):
self.assertAllClose(actual_image_np, image_np) self.assertAllClose(actual_image_np, image_np)
self.assertEqual(actual_image_np.shape, image_np.shape) self.assertEqual(actual_image_np.shape, image_np.shape)
@parameterized.named_parameters(
('RGB_RAW', 128, 64, 3, tf.bfloat16.as_numpy_dtype),
('GREY_RAW', 32, 32, 1, tf.uint8.as_numpy_dtype))
def test_encode_raw_image_then_decode_raw_image(self, height, width,
num_channels, image_dtype):
image_np = fake_feature_generator.generate_image_np(height, width,
num_channels)
image_np = image_np.astype(image_dtype)
image_str = image_utils.encode_image(image_np, 'RAW')
actual_image_np = image_utils.decode_image(image_str, 'RAW', image_dtype)
actual_image_np = actual_image_np.reshape([height, width, num_channels])
self.assertAllClose(actual_image_np, image_np)
self.assertEqual(actual_image_np.shape, image_np.shape)
@parameterized.named_parameters( @parameterized.named_parameters(
('RGB_PNG', 128, 64, 3, 'PNG'), ('RGB_JPEG', 64, 128, 3, 'JPEG'), ('RGB_PNG', 128, 64, 3, 'PNG'), ('RGB_JPEG', 64, 128, 3, 'JPEG'),
('GREY_BMP', 32, 32, 1, 'BMP'), ('GREY_PNG', 128, 128, 1, 'png')) ('GREY_BMP', 32, 32, 1, 'BMP'), ('GREY_PNG', 128, 128, 1, 'png'))
......
...@@ -119,6 +119,10 @@ class TfExampleBuilder(tf_example_builder.TfExampleBuilder): ...@@ -119,6 +119,10 @@ class TfExampleBuilder(tf_example_builder.TfExampleBuilder):
Returns: Returns:
The builder object for subsequent method calls. The builder object for subsequent method calls.
""" """
if image_format == 'RAW':
if not (height and width and num_channels):
raise ValueError('For raw image feature, height, width and '
'num_channels fields are required.')
if not all((height, width, num_channels, image_format)): if not all((height, width, num_channels, image_format)):
(height, width, num_channels, image_format) = ( (height, width, num_channels, image_format) = (
image_utils.decode_image_metadata(encoded_image)) image_utils.decode_image_metadata(encoded_image))
......
...@@ -24,6 +24,7 @@ from official.vision.data import tf_example_builder ...@@ -24,6 +24,7 @@ from official.vision.data import tf_example_builder
class TfExampleBuilderTest(tf.test.TestCase, parameterized.TestCase): class TfExampleBuilderTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(('RGB_PNG', 128, 64, 3, 'PNG', '15125990'), @parameterized.named_parameters(('RGB_PNG', 128, 64, 3, 'PNG', '15125990'),
('RGB_RAW', 128, 128, 3, 'RAW', '5607919'),
('RGB_JPEG', 64, 128, 3, 'JPEG', '3107796')) ('RGB_JPEG', 64, 128, 3, 'JPEG', '3107796'))
def test_add_image_matrix_feature_success(self, height, width, num_channels, def test_add_image_matrix_feature_success(self, height, width, num_channels,
image_format, hashed_image): image_format, hashed_image):
...@@ -111,6 +112,62 @@ class TfExampleBuilderTest(tf.test.TestCase, parameterized.TestCase): ...@@ -111,6 +112,62 @@ class TfExampleBuilderTest(tf.test.TestCase, parameterized.TestCase):
bytes_list=tf.train.BytesList(value=[hashed_image])) bytes_list=tf.train.BytesList(value=[hashed_image]))
})), example) })), example)
def test_add_encoded_raw_image_feature_success(self):
height = 128
width = 128
num_channels = 3
image_format = 'RAW'
image_bytes = tf.bfloat16.as_numpy_dtype
image_np = fake_feature_generator.generate_image_np(height, width,
num_channels)
image_np = image_np.astype(image_bytes)
expected_image_bytes = image_utils.encode_image(image_np, image_format)
hashed_image = bytes('3572575', 'ascii')
example_builder = tf_example_builder.TfExampleBuilder()
example_builder.add_encoded_image_feature(expected_image_bytes, 'RAW',
height, width, num_channels)
example = example_builder.example
self.assertProtoEquals(
tf.train.Example(
features=tf.train.Features(
feature={
'image/encoded':
tf.train.Feature(
bytes_list=tf.train.BytesList(
value=[expected_image_bytes])),
'image/format':
tf.train.Feature(
bytes_list=tf.train.BytesList(
value=[bytes(image_format, 'ascii')])),
'image/height':
tf.train.Feature(
int64_list=tf.train.Int64List(value=[height])),
'image/width':
tf.train.Feature(
int64_list=tf.train.Int64List(value=[width])),
'image/channels':
tf.train.Feature(
int64_list=tf.train.Int64List(
value=[num_channels])),
'image/source_id':
tf.train.Feature(
bytes_list=tf.train.BytesList(value=[hashed_image]))
})), example)
def test_add_encoded_raw_image_feature_valueerror(self):
image_format = 'RAW'
image_bytes = tf.bfloat16.as_numpy_dtype
image_np = fake_feature_generator.generate_image_np(1, 1, 1)
image_np = image_np.astype(image_bytes)
expected_image_bytes = image_utils.encode_image(image_np, image_format)
example_builder = tf_example_builder.TfExampleBuilder()
with self.assertRaises(ValueError):
example_builder.add_encoded_image_feature(expected_image_bytes,
image_format)
@parameterized.parameters( @parameterized.parameters(
(True, True, True, True, True), (True, True, True, True, True),
(False, False, False, False, False), (False, False, False, False, False),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment