Commit 374cff58 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 351798345
parent 1fb4c559
...@@ -19,6 +19,7 @@ import io ...@@ -19,6 +19,7 @@ import io
import itertools import itertools
from absl import logging from absl import logging
import numpy as np
from PIL import Image from PIL import Image
import tensorflow as tf import tensorflow as tf
...@@ -45,10 +46,10 @@ def convert_to_feature(value, value_type=None): ...@@ -45,10 +46,10 @@ def convert_to_feature(value, value_type=None):
if isinstance(element, bytes): if isinstance(element, bytes):
value_type = 'bytes' value_type = 'bytes'
elif isinstance(element, int): elif isinstance(element, (int, np.integer)):
value_type = 'int64' value_type = 'int64'
elif isinstance(element, float): elif isinstance(element, (float, np.floating)):
value_type = 'float' value_type = 'float'
else: else:
...@@ -104,8 +105,9 @@ def encode_binary_mask_as_png(binary_mask): ...@@ -104,8 +105,9 @@ def encode_binary_mask_as_png(binary_mask):
return output_io.getvalue() return output_io.getvalue()
def write_tf_record_dataset(output_path, annotation_iterator, process_func, def write_tf_record_dataset(output_path, annotation_iterator,
num_shards, use_multiprocessing=True): process_func, num_shards,
use_multiprocessing=True, unpack_arguments=True):
"""Iterates over annotations, processes them and writes into TFRecords. """Iterates over annotations, processes them and writes into TFRecords.
Args: Args:
...@@ -118,6 +120,9 @@ def write_tf_record_dataset(output_path, annotation_iterator, process_func, ...@@ -118,6 +120,9 @@ def write_tf_record_dataset(output_path, annotation_iterator, process_func,
num_shards: int, the number of shards to write for the dataset. num_shards: int, the number of shards to write for the dataset.
use_multiprocessing: use_multiprocessing:
Whether or not to use multiple processes to write TF Records. Whether or not to use multiple processes to write TF Records.
unpack_arguments:
Whether to unpack the tuples from annotation_iterator as individual
arguments to the process func or to pass the returned value as it is.
Returns: Returns:
num_skipped: The total number of skipped annotations. num_skipped: The total number of skipped annotations.
...@@ -133,9 +138,15 @@ def write_tf_record_dataset(output_path, annotation_iterator, process_func, ...@@ -133,9 +138,15 @@ def write_tf_record_dataset(output_path, annotation_iterator, process_func,
if use_multiprocessing: if use_multiprocessing:
pool = mp.Pool() pool = mp.Pool()
tf_example_iterator = pool.starmap(process_func, annotation_iterator) if unpack_arguments:
tf_example_iterator = pool.starmap(process_func, annotation_iterator)
else:
tf_example_iterator = pool.imap(process_func, annotation_iterator)
else: else:
tf_example_iterator = itertools.starmap(process_func, annotation_iterator) if unpack_arguments:
tf_example_iterator = itertools.starmap(process_func, annotation_iterator)
else:
tf_example_iterator = map(process_func, annotation_iterator)
for idx, (tf_example, num_annotations_skipped) in enumerate( for idx, (tf_example, num_annotations_skipped) in enumerate(
tf_example_iterator): tf_example_iterator):
...@@ -155,3 +166,10 @@ def write_tf_record_dataset(output_path, annotation_iterator, process_func, ...@@ -155,3 +166,10 @@ def write_tf_record_dataset(output_path, annotation_iterator, process_func,
logging.info('Finished writing, skipped %d annotations.', logging.info('Finished writing, skipped %d annotations.',
total_num_annotations_skipped) total_num_annotations_skipped)
return total_num_annotations_skipped return total_num_annotations_skipped
def check_and_make_dir(directory):
  """Ensures `directory` exists, creating it (with parents) when missing.

  Args:
    directory: str, path of the directory to check/create. Uses
      `tf.io.gfile`, so GCS-style paths are supported as well as local ones.
  """
  # Guard clause: nothing to do when the directory is already present.
  if tf.io.gfile.isdir(directory):
    return
  tf.io.gfile.makedirs(directory)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment