"vscode:/vscode.git/clone" did not exist on "1b781f15c8783acb9ce6e3ebcc8d90437567f7cf"
Commit 60568599 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Add image labels to fake example generator

PiperOrigin-RevId: 475944607
parent 2eb655c4
......@@ -43,7 +43,8 @@ class TfExampleBuilder(tf_example_builder.TfExampleBuilder):
image_matrix: np.ndarray,
image_format: str = 'PNG',
image_source_id: Optional[bytes] = None,
feature_prefix: Optional[str] = None) -> 'TfExampleBuilder':
feature_prefix: Optional[str] = None,
label: Optional[Union[int, Sequence[int]]] = None) -> 'TfExampleBuilder':
"""Encodes and adds image features to the example.
See `tf_example_feature_key.EncodedImageFeatureKey` for list of feature keys
......@@ -67,6 +68,7 @@ class TfExampleBuilder(tf_example_builder.TfExampleBuilder):
image_source_id: Unique string ID to identify the image. Hashed image will
be used if the field is not provided.
feature_prefix: Feature prefix for image features.
label: the label or a list of labels for the image.
Returns:
The builder object for subsequent method calls.
......@@ -76,7 +78,7 @@ class TfExampleBuilder(tf_example_builder.TfExampleBuilder):
return self.add_encoded_image_feature(encoded_image, image_format, height,
width, num_channels, image_source_id,
feature_prefix)
feature_prefix, label)
def add_encoded_image_feature(
self,
......@@ -86,7 +88,8 @@ class TfExampleBuilder(tf_example_builder.TfExampleBuilder):
width: Optional[int] = None,
num_channels: Optional[int] = None,
image_source_id: Optional[bytes] = None,
feature_prefix: Optional[str] = None) -> 'TfExampleBuilder':
feature_prefix: Optional[str] = None,
label: Optional[Union[int, Sequence[int]]] = None) -> 'TfExampleBuilder':
"""Adds encoded image features to the example.
See `tf_example_feature_key.EncodedImageFeatureKey` for list of feature keys
......@@ -115,6 +118,7 @@ class TfExampleBuilder(tf_example_builder.TfExampleBuilder):
num_channels: Number of channels.
image_source_id: Unique string ID to identify the image.
feature_prefix: Feature prefix for image features.
label: the label or a list of labels for the image.
Returns:
The builder object for subsequent method calls.
......@@ -138,6 +142,9 @@ class TfExampleBuilder(tf_example_builder.TfExampleBuilder):
hashed_image = int(hashlib.blake2s(encoded_image).hexdigest(), 16)
image_source_id = _to_bytes(str(hash(hashed_image) % ((1 << 24) + 1)))
if label is not None:
self.add_ints_feature(feature_key.label, label)
return (
self.add_bytes_feature(feature_key.encoded, encoded_image)
.add_bytes_feature(feature_key.format, image_format)
......
......@@ -23,11 +23,12 @@ from official.vision.data import tf_example_builder
class TfExampleBuilderTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(('RGB_PNG', 128, 64, 3, 'PNG', '15125990'),
('RGB_RAW', 128, 128, 3, 'RAW', '5607919'),
('RGB_JPEG', 64, 128, 3, 'JPEG', '3107796'))
@parameterized.named_parameters(
('RGB_PNG', 128, 64, 3, 'PNG', '15125990', 3),
('RGB_RAW', 128, 128, 3, 'RAW', '5607919', 0),
('RGB_JPEG', 64, 128, 3, 'JPEG', '3107796', [2, 5]))
def test_add_image_matrix_feature_success(self, height, width, num_channels,
image_format, hashed_image):
image_format, hashed_image, label):
# Prepare test data.
image_np = fake_feature_generator.generate_image_np(height, width,
num_channels)
......@@ -36,12 +37,17 @@ class TfExampleBuilderTest(tf.test.TestCase, parameterized.TestCase):
# Run code logic.
example_builder = tf_example_builder.TfExampleBuilder()
example_builder.add_image_matrix_feature(image_np, image_format)
example_builder.add_image_matrix_feature(image_np, image_format,
label=label)
example = example_builder.example
# Verify outputs.
# Prefer to use string literal for feature keys to directly display the
# structure of the expected tf.train.Example.
if isinstance(label, int):
expected_labels = [label]
else:
expected_labels = label
self.assertProtoEquals(
tf.train.Example(
features=tf.train.Features(
......@@ -66,7 +72,12 @@ class TfExampleBuilderTest(tf.test.TestCase, parameterized.TestCase):
value=[num_channels])),
'image/source_id':
tf.train.Feature(
bytes_list=tf.train.BytesList(value=[hashed_image]))
bytes_list=tf.train.BytesList(
value=[hashed_image])),
'image/class/label':
tf.train.Feature(
int64_list=tf.train.Int64List(
value=expected_labels)),
})), example)
def test_add_image_matrix_feature_with_feature_prefix_success(self):
......@@ -75,6 +86,7 @@ class TfExampleBuilderTest(tf.test.TestCase, parameterized.TestCase):
num_channels = 1
image_format = 'PNG'
feature_prefix = 'depth'
label = 8
image_np = fake_feature_generator.generate_image_np(height, width,
num_channels)
expected_image_bytes = image_utils.encode_image(image_np, image_format)
......@@ -82,7 +94,7 @@ class TfExampleBuilderTest(tf.test.TestCase, parameterized.TestCase):
example_builder = tf_example_builder.TfExampleBuilder()
example_builder.add_image_matrix_feature(
image_np, image_format, feature_prefix=feature_prefix)
image_np, image_format, feature_prefix=feature_prefix, label=label)
example = example_builder.example
self.assertProtoEquals(
......@@ -109,7 +121,11 @@ class TfExampleBuilderTest(tf.test.TestCase, parameterized.TestCase):
value=[num_channels])),
'depth/image/source_id':
tf.train.Feature(
bytes_list=tf.train.BytesList(value=[hashed_image]))
bytes_list=tf.train.BytesList(
value=[hashed_image])),
'depth/image/class/label':
tf.train.Feature(
int64_list=tf.train.Int64List(value=[label]))
})), example)
def test_add_encoded_raw_image_feature_success(self):
......@@ -169,18 +185,20 @@ class TfExampleBuilderTest(tf.test.TestCase, parameterized.TestCase):
image_format)
@parameterized.parameters(
(True, True, True, True, True),
(False, False, False, False, False),
(True, False, False, False, False),
(False, True, False, False, False),
(False, False, True, False, False),
(False, False, False, True, False),
(False, False, False, False, True),
(True, True, True, True, True, True),
(False, False, False, False, False, False),
(True, False, False, False, False, False),
(False, True, False, False, False, False),
(False, False, True, False, False, False),
(False, False, False, True, False, False),
(False, False, False, False, True, False),
(False, False, False, False, False, True),
)
def test_add_encoded_image_feature_success(self, miss_image_format,
miss_height, miss_width,
miss_num_channels,
miss_image_source_id):
miss_image_source_id,
miss_label):
height = 64
width = 64
num_channels = 3
......@@ -189,12 +207,14 @@ class TfExampleBuilderTest(tf.test.TestCase, parameterized.TestCase):
num_channels)
image_bytes = image_utils.encode_image(image_np, image_format)
hashed_image = bytes('2968688', 'ascii')
label = 5
image_format = None if miss_image_format else image_format
height = None if miss_height else height
width = None if miss_width else width
num_channels = None if miss_num_channels else num_channels
image_source_id = None if miss_image_source_id else hashed_image
label = None if miss_label else label
example_builder = tf_example_builder.TfExampleBuilder()
example_builder.add_encoded_image_feature(
......@@ -203,33 +223,38 @@ class TfExampleBuilderTest(tf.test.TestCase, parameterized.TestCase):
height=height,
width=width,
num_channels=num_channels,
image_source_id=image_source_id)
image_source_id=image_source_id,
label=label)
example = example_builder.example
expected_features = {
'image/encoded':
tf.train.Feature(
bytes_list=tf.train.BytesList(value=[image_bytes])),
'image/format':
tf.train.Feature(
bytes_list=tf.train.BytesList(
value=[bytes('PNG', 'ascii')])),
'image/height':
tf.train.Feature(
int64_list=tf.train.Int64List(value=[64])),
'image/width':
tf.train.Feature(
int64_list=tf.train.Int64List(value=[64])),
'image/channels':
tf.train.Feature(
int64_list=tf.train.Int64List(value=[3])),
'image/source_id':
tf.train.Feature(
bytes_list=tf.train.BytesList(value=[hashed_image]))}
if not miss_label:
expected_features.update({
'image/class/label':
tf.train.Feature(
int64_list=tf.train.Int64List(value=[label]))})
self.assertProtoEquals(
tf.train.Example(
features=tf.train.Features(
feature={
'image/encoded':
tf.train.Feature(
bytes_list=tf.train.BytesList(value=[image_bytes])),
'image/format':
tf.train.Feature(
bytes_list=tf.train.BytesList(
value=[bytes('PNG', 'ascii')])),
'image/height':
tf.train.Feature(
int64_list=tf.train.Int64List(value=[64])),
'image/width':
tf.train.Feature(
int64_list=tf.train.Int64List(value=[64])),
'image/channels':
tf.train.Feature(
int64_list=tf.train.Int64List(value=[3])),
'image/source_id':
tf.train.Feature(
bytes_list=tf.train.BytesList(value=[hashed_image]))
})), example)
tf.train.Example(features=tf.train.Features(feature=expected_features)),
example)
@parameterized.named_parameters(('no_box', 0), ('10_boxes', 10))
def test_add_normalized_boxes_feature(self, num_boxes):
......
......@@ -40,6 +40,7 @@ class EncodedImageFeatureKey(tf_example_feature_key.TfExampleFeatureKeyBase):
width: number of columns.
num_channels: number of channels.
source_id: Unique string ID to identify the image.
label: the label or a list of labels for the image.
"""
encoded: str = 'image/encoded'
format: str = 'image/format'
......@@ -47,6 +48,7 @@ class EncodedImageFeatureKey(tf_example_feature_key.TfExampleFeatureKeyBase):
width: str = 'image/width'
num_channels: str = 'image/channels'
source_id: str = 'image/source_id'
label: str = 'image/class/label'
@dataclass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment