Commit b5286d69 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change.

PiperOrigin-RevId: 441761682
parent 471451cc
......@@ -68,7 +68,7 @@ flags.DEFINE_boolean(
'default: False.')
flags.DEFINE_string('output_file_prefix', '/tmp/train', 'Path to output file')
flags.DEFINE_integer('num_shards', 32, 'Number of shards for output file.')
_NUM_PROCESSES = flags.DEFINE_string(
_NUM_PROCESSES = flags.DEFINE_integer(
'num_processes', None,
('Number of parallel processes to use. '
'If set to 0, disables multi-processing.'))
......
......@@ -26,6 +26,9 @@ import tensorflow as tf
import multiprocessing as mp
LOG_EVERY = 100
def convert_to_feature(value, value_type=None):
"""Converts the given python object to a tf.train.Feature.
......@@ -147,7 +150,7 @@ def write_tf_record_dataset(output_path, annotation_iterator,
total_num_annotations_skipped = 0
if multiple_processes is None or multiple_processes > 0:
pool = g3_mp.get_context(g3_mp.ABSL_FORKSERVER).Pool(
pool = mp.Pool(
processes=multiple_processes)
if unpack_arguments:
tf_example_iterator = pool.starmap(process_func, annotation_iterator)
......@@ -161,7 +164,7 @@ def write_tf_record_dataset(output_path, annotation_iterator,
for idx, (tf_example, num_annotations_skipped) in enumerate(
tf_example_iterator):
if idx % 100 == 0:
if idx % LOG_EVERY == 0:
logging.info('On image %d', idx)
total_num_annotations_skipped += num_annotations_skipped
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment