Commit 374cff58 authored by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 351798345
parent 1fb4c559
......@@ -19,6 +19,7 @@ import io
import itertools
from absl import logging
import numpy as np
from PIL import Image
import tensorflow as tf
......@@ -45,10 +46,10 @@ def convert_to_feature(value, value_type=None):
if isinstance(element, bytes):
value_type = 'bytes'
elif isinstance(element, int):
elif isinstance(element, (int, np.integer)):
value_type = 'int64'
elif isinstance(element, float):
elif isinstance(element, (float, np.floating)):
value_type = 'float'
else:
......@@ -104,8 +105,9 @@ def encode_binary_mask_as_png(binary_mask):
return output_io.getvalue()
def write_tf_record_dataset(output_path, annotation_iterator, process_func,
num_shards, use_multiprocessing=True):
def write_tf_record_dataset(output_path, annotation_iterator,
process_func, num_shards,
use_multiprocessing=True, unpack_arguments=True):
"""Iterates over annotations, processes them and writes into TFRecords.
Args:
......@@ -118,6 +120,9 @@ def write_tf_record_dataset(output_path, annotation_iterator, process_func,
num_shards: int, the number of shards to write for the dataset.
use_multiprocessing:
Whether or not to use multiple processes to write TF Records.
unpack_arguments:
Whether to unpack the tuples from annotation_iterator as individual
arguments to the process func or to pass the returned value as it is.
Returns:
num_skipped: The total number of skipped annotations.
......@@ -133,9 +138,15 @@ def write_tf_record_dataset(output_path, annotation_iterator, process_func,
if use_multiprocessing:
pool = mp.Pool()
tf_example_iterator = pool.starmap(process_func, annotation_iterator)
if unpack_arguments:
tf_example_iterator = pool.starmap(process_func, annotation_iterator)
else:
tf_example_iterator = pool.imap(process_func, annotation_iterator)
else:
tf_example_iterator = itertools.starmap(process_func, annotation_iterator)
if unpack_arguments:
tf_example_iterator = itertools.starmap(process_func, annotation_iterator)
else:
tf_example_iterator = map(process_func, annotation_iterator)
for idx, (tf_example, num_annotations_skipped) in enumerate(
tf_example_iterator):
......@@ -155,3 +166,10 @@ def write_tf_record_dataset(output_path, annotation_iterator, process_func,
logging.info('Finished writing, skipped %d annotations.',
total_num_annotations_skipped)
return total_num_annotations_skipped
def check_and_make_dir(directory):
  """Ensures that `directory` exists.

  Args:
    directory: Path of the directory to check and, if missing, create.
  """
  # Nothing to do when the path is already a directory.
  if tf.io.gfile.isdir(directory):
    return
  tf.io.gfile.makedirs(directory)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment