Commit 60568599 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Add image labels to fake example generator

PiperOrigin-RevId: 475944607
parent 2eb655c4
...@@ -43,7 +43,8 @@ class TfExampleBuilder(tf_example_builder.TfExampleBuilder): ...@@ -43,7 +43,8 @@ class TfExampleBuilder(tf_example_builder.TfExampleBuilder):
image_matrix: np.ndarray, image_matrix: np.ndarray,
image_format: str = 'PNG', image_format: str = 'PNG',
image_source_id: Optional[bytes] = None, image_source_id: Optional[bytes] = None,
feature_prefix: Optional[str] = None) -> 'TfExampleBuilder': feature_prefix: Optional[str] = None,
label: Optional[Union[int, Sequence[int]]] = None) -> 'TfExampleBuilder':
"""Encodes and adds image features to the example. """Encodes and adds image features to the example.
See `tf_example_feature_key.EncodedImageFeatureKey` for list of feature keys See `tf_example_feature_key.EncodedImageFeatureKey` for list of feature keys
...@@ -67,6 +68,7 @@ class TfExampleBuilder(tf_example_builder.TfExampleBuilder): ...@@ -67,6 +68,7 @@ class TfExampleBuilder(tf_example_builder.TfExampleBuilder):
image_source_id: Unique string ID to identify the image. Hashed image will image_source_id: Unique string ID to identify the image. Hashed image will
be used if the field is not provided. be used if the field is not provided.
feature_prefix: Feature prefix for image features. feature_prefix: Feature prefix for image features.
label: the label or a list of labels for the image.
Returns: Returns:
The builder object for subsequent method calls. The builder object for subsequent method calls.
...@@ -76,7 +78,7 @@ class TfExampleBuilder(tf_example_builder.TfExampleBuilder): ...@@ -76,7 +78,7 @@ class TfExampleBuilder(tf_example_builder.TfExampleBuilder):
return self.add_encoded_image_feature(encoded_image, image_format, height, return self.add_encoded_image_feature(encoded_image, image_format, height,
width, num_channels, image_source_id, width, num_channels, image_source_id,
feature_prefix) feature_prefix, label)
def add_encoded_image_feature( def add_encoded_image_feature(
self, self,
...@@ -86,7 +88,8 @@ class TfExampleBuilder(tf_example_builder.TfExampleBuilder): ...@@ -86,7 +88,8 @@ class TfExampleBuilder(tf_example_builder.TfExampleBuilder):
width: Optional[int] = None, width: Optional[int] = None,
num_channels: Optional[int] = None, num_channels: Optional[int] = None,
image_source_id: Optional[bytes] = None, image_source_id: Optional[bytes] = None,
feature_prefix: Optional[str] = None) -> 'TfExampleBuilder': feature_prefix: Optional[str] = None,
label: Optional[Union[int, Sequence[int]]] = None) -> 'TfExampleBuilder':
"""Adds encoded image features to the example. """Adds encoded image features to the example.
See `tf_example_feature_key.EncodedImageFeatureKey` for list of feature keys See `tf_example_feature_key.EncodedImageFeatureKey` for list of feature keys
...@@ -115,6 +118,7 @@ class TfExampleBuilder(tf_example_builder.TfExampleBuilder): ...@@ -115,6 +118,7 @@ class TfExampleBuilder(tf_example_builder.TfExampleBuilder):
num_channels: Number of channels. num_channels: Number of channels.
image_source_id: Unique string ID to identify the image. image_source_id: Unique string ID to identify the image.
feature_prefix: Feature prefix for image features. feature_prefix: Feature prefix for image features.
label: the label or a list of labels for the image.
Returns: Returns:
The builder object for subsequent method calls. The builder object for subsequent method calls.
...@@ -138,6 +142,9 @@ class TfExampleBuilder(tf_example_builder.TfExampleBuilder): ...@@ -138,6 +142,9 @@ class TfExampleBuilder(tf_example_builder.TfExampleBuilder):
hashed_image = int(hashlib.blake2s(encoded_image).hexdigest(), 16) hashed_image = int(hashlib.blake2s(encoded_image).hexdigest(), 16)
image_source_id = _to_bytes(str(hash(hashed_image) % ((1 << 24) + 1))) image_source_id = _to_bytes(str(hash(hashed_image) % ((1 << 24) + 1)))
if label is not None:
self.add_ints_feature(feature_key.label, label)
return ( return (
self.add_bytes_feature(feature_key.encoded, encoded_image) self.add_bytes_feature(feature_key.encoded, encoded_image)
.add_bytes_feature(feature_key.format, image_format) .add_bytes_feature(feature_key.format, image_format)
......
...@@ -23,11 +23,12 @@ from official.vision.data import tf_example_builder ...@@ -23,11 +23,12 @@ from official.vision.data import tf_example_builder
class TfExampleBuilderTest(tf.test.TestCase, parameterized.TestCase): class TfExampleBuilderTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(('RGB_PNG', 128, 64, 3, 'PNG', '15125990'), @parameterized.named_parameters(
('RGB_RAW', 128, 128, 3, 'RAW', '5607919'), ('RGB_PNG', 128, 64, 3, 'PNG', '15125990', 3),
('RGB_JPEG', 64, 128, 3, 'JPEG', '3107796')) ('RGB_RAW', 128, 128, 3, 'RAW', '5607919', 0),
('RGB_JPEG', 64, 128, 3, 'JPEG', '3107796', [2, 5]))
def test_add_image_matrix_feature_success(self, height, width, num_channels, def test_add_image_matrix_feature_success(self, height, width, num_channels,
image_format, hashed_image): image_format, hashed_image, label):
# Prepare test data. # Prepare test data.
image_np = fake_feature_generator.generate_image_np(height, width, image_np = fake_feature_generator.generate_image_np(height, width,
num_channels) num_channels)
...@@ -36,12 +37,17 @@ class TfExampleBuilderTest(tf.test.TestCase, parameterized.TestCase): ...@@ -36,12 +37,17 @@ class TfExampleBuilderTest(tf.test.TestCase, parameterized.TestCase):
# Run code logic. # Run code logic.
example_builder = tf_example_builder.TfExampleBuilder() example_builder = tf_example_builder.TfExampleBuilder()
example_builder.add_image_matrix_feature(image_np, image_format) example_builder.add_image_matrix_feature(image_np, image_format,
label=label)
example = example_builder.example example = example_builder.example
# Verify outputs. # Verify outputs.
# Prefer to use string literal for feature keys to directly display the # Prefer to use string literal for feature keys to directly display the
# structure of the expected tf.train.Example. # structure of the expected tf.train.Example.
if isinstance(label, int):
expected_labels = [label]
else:
expected_labels = label
self.assertProtoEquals( self.assertProtoEquals(
tf.train.Example( tf.train.Example(
features=tf.train.Features( features=tf.train.Features(
...@@ -66,7 +72,12 @@ class TfExampleBuilderTest(tf.test.TestCase, parameterized.TestCase): ...@@ -66,7 +72,12 @@ class TfExampleBuilderTest(tf.test.TestCase, parameterized.TestCase):
value=[num_channels])), value=[num_channels])),
'image/source_id': 'image/source_id':
tf.train.Feature( tf.train.Feature(
bytes_list=tf.train.BytesList(value=[hashed_image])) bytes_list=tf.train.BytesList(
value=[hashed_image])),
'image/class/label':
tf.train.Feature(
int64_list=tf.train.Int64List(
value=expected_labels)),
})), example) })), example)
def test_add_image_matrix_feature_with_feature_prefix_success(self): def test_add_image_matrix_feature_with_feature_prefix_success(self):
...@@ -75,6 +86,7 @@ class TfExampleBuilderTest(tf.test.TestCase, parameterized.TestCase): ...@@ -75,6 +86,7 @@ class TfExampleBuilderTest(tf.test.TestCase, parameterized.TestCase):
num_channels = 1 num_channels = 1
image_format = 'PNG' image_format = 'PNG'
feature_prefix = 'depth' feature_prefix = 'depth'
label = 8
image_np = fake_feature_generator.generate_image_np(height, width, image_np = fake_feature_generator.generate_image_np(height, width,
num_channels) num_channels)
expected_image_bytes = image_utils.encode_image(image_np, image_format) expected_image_bytes = image_utils.encode_image(image_np, image_format)
...@@ -82,7 +94,7 @@ class TfExampleBuilderTest(tf.test.TestCase, parameterized.TestCase): ...@@ -82,7 +94,7 @@ class TfExampleBuilderTest(tf.test.TestCase, parameterized.TestCase):
example_builder = tf_example_builder.TfExampleBuilder() example_builder = tf_example_builder.TfExampleBuilder()
example_builder.add_image_matrix_feature( example_builder.add_image_matrix_feature(
image_np, image_format, feature_prefix=feature_prefix) image_np, image_format, feature_prefix=feature_prefix, label=label)
example = example_builder.example example = example_builder.example
self.assertProtoEquals( self.assertProtoEquals(
...@@ -109,7 +121,11 @@ class TfExampleBuilderTest(tf.test.TestCase, parameterized.TestCase): ...@@ -109,7 +121,11 @@ class TfExampleBuilderTest(tf.test.TestCase, parameterized.TestCase):
value=[num_channels])), value=[num_channels])),
'depth/image/source_id': 'depth/image/source_id':
tf.train.Feature( tf.train.Feature(
bytes_list=tf.train.BytesList(value=[hashed_image])) bytes_list=tf.train.BytesList(
value=[hashed_image])),
'depth/image/class/label':
tf.train.Feature(
int64_list=tf.train.Int64List(value=[label]))
})), example) })), example)
def test_add_encoded_raw_image_feature_success(self): def test_add_encoded_raw_image_feature_success(self):
...@@ -169,18 +185,20 @@ class TfExampleBuilderTest(tf.test.TestCase, parameterized.TestCase): ...@@ -169,18 +185,20 @@ class TfExampleBuilderTest(tf.test.TestCase, parameterized.TestCase):
image_format) image_format)
@parameterized.parameters( @parameterized.parameters(
(True, True, True, True, True), (True, True, True, True, True, True),
(False, False, False, False, False), (False, False, False, False, False, False),
(True, False, False, False, False), (True, False, False, False, False, False),
(False, True, False, False, False), (False, True, False, False, False, False),
(False, False, True, False, False), (False, False, True, False, False, False),
(False, False, False, True, False), (False, False, False, True, False, False),
(False, False, False, False, True), (False, False, False, False, True, False),
(False, False, False, False, False, True),
) )
def test_add_encoded_image_feature_success(self, miss_image_format, def test_add_encoded_image_feature_success(self, miss_image_format,
miss_height, miss_width, miss_height, miss_width,
miss_num_channels, miss_num_channels,
miss_image_source_id): miss_image_source_id,
miss_label):
height = 64 height = 64
width = 64 width = 64
num_channels = 3 num_channels = 3
...@@ -189,12 +207,14 @@ class TfExampleBuilderTest(tf.test.TestCase, parameterized.TestCase): ...@@ -189,12 +207,14 @@ class TfExampleBuilderTest(tf.test.TestCase, parameterized.TestCase):
num_channels) num_channels)
image_bytes = image_utils.encode_image(image_np, image_format) image_bytes = image_utils.encode_image(image_np, image_format)
hashed_image = bytes('2968688', 'ascii') hashed_image = bytes('2968688', 'ascii')
label = 5
image_format = None if miss_image_format else image_format image_format = None if miss_image_format else image_format
height = None if miss_height else height height = None if miss_height else height
width = None if miss_width else width width = None if miss_width else width
num_channels = None if miss_num_channels else num_channels num_channels = None if miss_num_channels else num_channels
image_source_id = None if miss_image_source_id else hashed_image image_source_id = None if miss_image_source_id else hashed_image
label = None if miss_label else label
example_builder = tf_example_builder.TfExampleBuilder() example_builder = tf_example_builder.TfExampleBuilder()
example_builder.add_encoded_image_feature( example_builder.add_encoded_image_feature(
...@@ -203,13 +223,11 @@ class TfExampleBuilderTest(tf.test.TestCase, parameterized.TestCase): ...@@ -203,13 +223,11 @@ class TfExampleBuilderTest(tf.test.TestCase, parameterized.TestCase):
height=height, height=height,
width=width, width=width,
num_channels=num_channels, num_channels=num_channels,
image_source_id=image_source_id) image_source_id=image_source_id,
label=label)
example = example_builder.example example = example_builder.example
self.assertProtoEquals( expected_features = {
tf.train.Example(
features=tf.train.Features(
feature={
'image/encoded': 'image/encoded':
tf.train.Feature( tf.train.Feature(
bytes_list=tf.train.BytesList(value=[image_bytes])), bytes_list=tf.train.BytesList(value=[image_bytes])),
...@@ -228,8 +246,15 @@ class TfExampleBuilderTest(tf.test.TestCase, parameterized.TestCase): ...@@ -228,8 +246,15 @@ class TfExampleBuilderTest(tf.test.TestCase, parameterized.TestCase):
int64_list=tf.train.Int64List(value=[3])), int64_list=tf.train.Int64List(value=[3])),
'image/source_id': 'image/source_id':
tf.train.Feature( tf.train.Feature(
bytes_list=tf.train.BytesList(value=[hashed_image])) bytes_list=tf.train.BytesList(value=[hashed_image]))}
})), example) if not miss_label:
expected_features.update({
'image/class/label':
tf.train.Feature(
int64_list=tf.train.Int64List(value=[label]))})
self.assertProtoEquals(
tf.train.Example(features=tf.train.Features(feature=expected_features)),
example)
@parameterized.named_parameters(('no_box', 0), ('10_boxes', 10)) @parameterized.named_parameters(('no_box', 0), ('10_boxes', 10))
def test_add_normalized_boxes_feature(self, num_boxes): def test_add_normalized_boxes_feature(self, num_boxes):
......
...@@ -40,6 +40,7 @@ class EncodedImageFeatureKey(tf_example_feature_key.TfExampleFeatureKeyBase): ...@@ -40,6 +40,7 @@ class EncodedImageFeatureKey(tf_example_feature_key.TfExampleFeatureKeyBase):
width: number of columns. width: number of columns.
num_channels: number of channels. num_channels: number of channels.
source_id: Unique string ID to identify the image. source_id: Unique string ID to identify the image.
label: the label or a list of labels for the image.
""" """
encoded: str = 'image/encoded' encoded: str = 'image/encoded'
format: str = 'image/format' format: str = 'image/format'
...@@ -47,6 +48,7 @@ class EncodedImageFeatureKey(tf_example_feature_key.TfExampleFeatureKeyBase): ...@@ -47,6 +48,7 @@ class EncodedImageFeatureKey(tf_example_feature_key.TfExampleFeatureKeyBase):
width: str = 'image/width' width: str = 'image/width'
num_channels: str = 'image/channels' num_channels: str = 'image/channels'
source_id: str = 'image/source_id' source_id: str = 'image/source_id'
label: str = 'image/class/label'
@dataclass @dataclass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment