Commit 54659689 authored by A. Unique TensorFlower

Make # of processes configurable in create_coco_tf_record.py

PiperOrigin-RevId: 441311717
parent 7fb0abb2
......@@ -68,6 +68,11 @@ flags.DEFINE_boolean(
'default: False.')
flags.DEFINE_string('output_file_prefix', '/tmp/train', 'Path to output file')
flags.DEFINE_integer('num_shards', 32, 'Number of shards for output file.')
_NUM_PROCESSES = flags.DEFINE_integer(
'num_processes', None,
('Number of parallel processes to use. '
'If set to 0, disables multi-processing.'))
FLAGS = flags.FLAGS
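
For readers unfamiliar with the pattern in this hunk: `flags.DEFINE_integer` returns an absl `FlagHolder`, and the parsed value is read through its `.value` attribute once `app.run` has parsed the command line. A minimal standalone sketch of that pattern, assuming only absl-py is installed:

from absl import app
from absl import flags

# Same holder pattern as above: None means "use the multiprocessing
# default", 0 disables multi-processing entirely.
_NUM_PROCESSES = flags.DEFINE_integer(
    'num_processes', None,
    ('Number of parallel processes to use. '
     'If set to 0, disables multi-processing.'))


def main(_):
  # .value is only safe to read after absl has parsed the flags.
  print('num_processes:', _NUM_PROCESSES.value)


if __name__ == '__main__':
  app.run(main)  # e.g. python create_coco_tf_record.py --num_processes=4
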
......@@ -518,7 +523,8 @@ def _create_tf_record_from_coco_annotations(images_info_file,
include_masks=include_masks)
num_skipped = tfrecord_lib.write_tf_record_dataset(
output_path, coco_annotations_iter, create_tf_example, num_shards)
output_path, coco_annotations_iter, create_tf_example, num_shards,
multiple_processes=_NUM_PROCESSES.value)
logging.info('Finished writing, skipped %d annotations.', num_skipped)
......
......@@ -114,7 +114,7 @@ def encode_mask_as_png(mask):
def write_tf_record_dataset(output_path, annotation_iterator,
process_func, num_shards,
use_multiprocessing=True, unpack_arguments=True):
multiple_processes=None, unpack_arguments=True):
"""Iterates over annotations, processes them and writes into TFRecords.
Args:
......@@ -125,7 +125,10 @@ def write_tf_record_dataset(output_path, annotation_iterator,
annotation_iterator as arguments and returns a tuple of (tf.train.Example,
int). The integer indicates the number of annotations that were skipped.
num_shards: int, the number of shards to write for the dataset.
use_multiprocessing:
Whether or not to use multiple processes to write TF Records.
multiple_processes: integer, the number of parallel processes to use.
If None, uses multi-processing with the number of processes equal to
`os.cpu_count()`, which is Python's default behavior. If set to 0,
multi-processing is disabled.
unpack_arguments:
Whether to unpack the tuples from annotation_iterator as individual
......@@ -143,8 +146,9 @@ def write_tf_record_dataset(output_path, annotation_iterator,
total_num_annotations_skipped = 0
if use_multiprocessing:
pool = mp.Pool()
if multiple_processes is None or multiple_processes > 0:
pool = mp.Pool(processes=multiple_processes)
if unpack_arguments:
tf_example_iterator = pool.starmap(process_func, annotation_iterator)
else:
......@@ -163,7 +167,7 @@ def write_tf_record_dataset(output_path, annotation_iterator,
total_num_annotations_skipped += num_annotations_skipped
writers[idx % num_shards].write(tf_example.SerializeToString())
if use_multiprocessing:
if multiple_processes is None or multiple_processes > 0:
pool.close()
pool.join()
......
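
Taken in isolation, the new control flow has three behaviors: None lets `multiprocessing.Pool` fall back to `os.cpu_count()` workers (Python's default), a positive value fixes the pool size, and 0 skips the pool entirely. The sketch below is a hypothetical simplified helper illustrating those semantics, not the library function itself:

import multiprocessing as mp


def map_samples(process_func, samples, multiple_processes=None):
  # Mirrors the pool-selection logic in write_tf_record_dataset above.
  # process_func must be picklable (module-level) for the pool path.
  if multiple_processes is None or multiple_processes > 0:
    # processes=None makes Pool spawn os.cpu_count() workers.
    with mp.Pool(processes=multiple_processes) as pool:
      return pool.map(process_func, samples)
  # multiple_processes == 0: no pool, run sequentially in this process.
  return [process_func(s) for s in samples]
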
......@@ -47,7 +47,7 @@ class TfrecordLibTest(parameterized.TestCase):
path = os.path.join(FLAGS.test_tmpdir, 'train')
tfrecord_lib.write_tf_record_dataset(
path, data, process_sample, 3, use_multiprocessing=False)
path, data, process_sample, 3, multiple_processes=0)
tfrecord_files = tf.io.gfile.glob(path + '*')
self.assertLen(tfrecord_files, 3)
......
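
Put together, the updated test exercises the sequential path end to end. A self-contained sketch of the same pattern, assuming `tfrecord_lib` is importable from the repository; the `process_sample` below is a hypothetical stand-in for the one used in the real test:

import os
import tensorflow as tf
from official.vision.data import tfrecord_lib


def process_sample(value):
  # Hypothetical stand-in: wrap one int in a tf.train.Example and report
  # zero skipped annotations, matching the (example, num_skipped) contract.
  feature = {'value': tf.train.Feature(
      int64_list=tf.train.Int64List(value=[value]))}
  example = tf.train.Example(features=tf.train.Features(feature=feature))
  return example, 0


data = [(i,) for i in range(10)]  # tuples, since unpack_arguments=True
path = os.path.join('/tmp', 'tfrecord_demo', 'train')
os.makedirs(os.path.dirname(path), exist_ok=True)

# multiple_processes=0 takes the sequential branch, so process_sample does
# not need to be pickled across worker processes.
tfrecord_lib.write_tf_record_dataset(path, data, process_sample, 3,
                                     multiple_processes=0)
print(tf.io.gfile.glob(path + '*'))  # expect 3 shard files
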