"examples/vscode:/vscode.git/clone" did not exist on "2ecd2b23ec7e8f48c0e7286dad306d7265e17a29"
Commit daefafa7 authored by Abdullah Rashwan, committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 370768021
parent ab8192ee
......@@ -34,6 +34,7 @@ class DataConfig(cfg.DataConfig):
dtype: str = 'float32'
shuffle_buffer_size: int = 10000
cycle_length: int = 10
is_multilabel: bool = False
aug_rand_hflip: bool = True
aug_type: Optional[
common.Augmentation] = None # Choose from AutoAugment and RandAugment.
......
......@@ -32,11 +32,18 @@ class Decoder(decoder.Decoder):
def __init__(self,
             image_field_key: str = 'image/encoded',
             label_field_key: str = 'image/class/label',
             is_multilabel: bool = False):
  """Builds the feature spec used to parse classification examples.

  Args:
    image_field_key: Feature key of the image; parsed as a scalar string
      feature that defaults to ''.
    label_field_key: Feature key of the class label(s).
    is_multilabel: If True, the label is parsed as a variable-length int64
      feature; otherwise as a scalar int64 feature that defaults to -1.
  """
  self._keys_to_features = {
      image_field_key: tf.io.FixedLenFeature((), tf.string, default_value=''),
  }
  if is_multilabel:
    # The number of labels per example is not fixed, so a VarLenFeature is
    # required here.
    label_feature = tf.io.VarLenFeature(dtype=tf.int64)
  else:
    label_feature = tf.io.FixedLenFeature((), tf.int64, default_value=-1)
  self._keys_to_features[label_field_key] = label_feature
def decode(self,
serialized_example: tf.train.Example) -> Dict[str, tf.Tensor]:
......@@ -54,6 +61,7 @@ class Parser(parser.Parser):
label_field_key: str = 'image/class/label',
aug_rand_hflip: bool = True,
aug_type: Optional[common.Augmentation] = None,
is_multilabel: bool = False,
dtype: str = 'float32'):
"""Initializes parameters for parsing annotations in the dataset.
......@@ -67,6 +75,7 @@ class Parser(parser.Parser):
horizontal flip.
aug_type: An optional Augmentation object to choose from AutoAugment and
RandAugment.
is_multilabel: A `bool`, whether or not each example has multiple labels.
dtype: `str`, cast output image in dtype. It can be 'float32', 'float16',
or 'bfloat16'.
"""
......@@ -75,6 +84,7 @@ class Parser(parser.Parser):
self._num_classes = num_classes
self._image_field_key = image_field_key
self._label_field_key = label_field_key
self._is_multilabel = is_multilabel
if dtype == 'float32':
self._dtype = tf.float32
......@@ -136,6 +146,11 @@ class Parser(parser.Parser):
# Convert image to self._dtype.
image = tf.image.convert_image_dtype(image, self._dtype)
if self._is_multilabel:
if isinstance(label, tf.sparse.SparseTensor):
label = tf.sparse.to_dense(label)
label = tf.reduce_sum(tf.one_hot(label, self._num_classes), axis=0)
return image, label
def _parse_eval_data(self, decoded_tensors):
......@@ -160,4 +175,9 @@ class Parser(parser.Parser):
# Convert image to self._dtype.
image = tf.image.convert_image_dtype(image, self._dtype)
if self._is_multilabel:
if isinstance(label, tf.sparse.SparseTensor):
label = tf.sparse.to_dense(label)
label = tf.reduce_sum(tf.one_hot(label, self._num_classes), axis=0)
return image, label
......@@ -46,12 +46,12 @@ class FooTrainTest(tf.test.TestCase):
import io
from typing import Sequence, Union
# Import libraries
import numpy as np
from PIL import Image
import tensorflow as tf
# Feature-map keys used when building example protos.
# Key under which the encoded image bytes are stored.
IMAGE_KEY = 'image/encoded'
# Key under which the int64 classification label(s) are stored.
CLASSIFICATION_LABEL_KEY = 'image/class/label'
# NOTE(review): the two keys below are not referenced in the visible part of
# this file; presumably they belong to other example builders -- confirm.
LABEL_KEY = 'clip/label/index'
AUDIO_KEY = 'features/audio'
......@@ -114,3 +114,30 @@ def dump_to_tfrecord(record_file: str,
with tf.io.TFRecordWriter(record_file) as writer:
for tf_example in tf_examples:
writer.write(tf_example.SerializeToString())
def _encode_image(image_array: np.ndarray, fmt: str) -> bytes:
  """Encodes a pixel array into image bytes of the requested format.

  Args:
    image_array: Array of pixel values, converted with `Image.fromarray`.
    fmt: Format name forwarded to `Image.save` (e.g. 'JPEG').

  Returns:
    The encoded image contents as bytes.
  """
  with io.BytesIO() as buffer:
    Image.fromarray(image_array).save(buffer, format=fmt)
    encoded = buffer.getvalue()
  return encoded
def create_classification_example(
    image_height: int,
    image_width: int,
    is_multilabel: bool = False) -> bytes:
  """Creates a serialized example for the image classification pipeline.

  The image is random noise encoded as JPEG; the labels are fixed.

  Args:
    image_height: Height in pixels of the generated image.
    image_width: Width in pixels of the generated image.
    is_multilabel: If True, the example carries the labels [0, 1]; otherwise
      it carries the single label [0].

  Returns:
    The `tf.train.Example` proto serialized to bytes (not the proto object
    itself).
  """
  pixels = np.uint8(np.random.rand(image_height, image_width, 3) * 255)
  image = _encode_image(pixels, fmt='JPEG')
  labels = [0, 1] if is_multilabel else [0]
  example = tf.train.Example(
      features=tf.train.Features(
          feature={
              IMAGE_KEY: tf.train.Feature(
                  bytes_list=tf.train.BytesList(value=[image])),
              CLASSIFICATION_LABEL_KEY: tf.train.Feature(
                  int64_list=tf.train.Int64List(value=labels)),
          }))
  # Callers receive the wire-format bytes, so serialize before returning.
  return example.SerializeToString()
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment