Commit b5286d69 authored by A. Unique TensorFlower
Browse files

Internal change.

PiperOrigin-RevId: 441761682
parent 471451cc
...@@ -68,7 +68,7 @@ flags.DEFINE_boolean( ...@@ -68,7 +68,7 @@ flags.DEFINE_boolean(
'default: False.') 'default: False.')
flags.DEFINE_string('output_file_prefix', '/tmp/train', 'Path to output file') flags.DEFINE_string('output_file_prefix', '/tmp/train', 'Path to output file')
flags.DEFINE_integer('num_shards', 32, 'Number of shards for output file.') flags.DEFINE_integer('num_shards', 32, 'Number of shards for output file.')
_NUM_PROCESSES = flags.DEFINE_string( _NUM_PROCESSES = flags.DEFINE_integer(
'num_processes', None, 'num_processes', None,
('Number of parallel processes to use. ' ('Number of parallel processes to use. '
'If set to 0, disables multi-processing.')) 'If set to 0, disables multi-processing.'))
......
...@@ -26,6 +26,9 @@ import tensorflow as tf ...@@ -26,6 +26,9 @@ import tensorflow as tf
import multiprocessing as mp import multiprocessing as mp
LOG_EVERY = 100
def convert_to_feature(value, value_type=None): def convert_to_feature(value, value_type=None):
"""Converts the given python object to a tf.train.Feature. """Converts the given python object to a tf.train.Feature.
...@@ -147,7 +150,7 @@ def write_tf_record_dataset(output_path, annotation_iterator, ...@@ -147,7 +150,7 @@ def write_tf_record_dataset(output_path, annotation_iterator,
total_num_annotations_skipped = 0 total_num_annotations_skipped = 0
if multiple_processes is None or multiple_processes > 0: if multiple_processes is None or multiple_processes > 0:
pool = g3_mp.get_context(g3_mp.ABSL_FORKSERVER).Pool( pool = mp.Pool(
processes=multiple_processes) processes=multiple_processes)
if unpack_arguments: if unpack_arguments:
tf_example_iterator = pool.starmap(process_func, annotation_iterator) tf_example_iterator = pool.starmap(process_func, annotation_iterator)
...@@ -161,7 +164,7 @@ def write_tf_record_dataset(output_path, annotation_iterator, ...@@ -161,7 +164,7 @@ def write_tf_record_dataset(output_path, annotation_iterator,
for idx, (tf_example, num_annotations_skipped) in enumerate( for idx, (tf_example, num_annotations_skipped) in enumerate(
tf_example_iterator): tf_example_iterator):
if idx % 100 == 0: if idx % LOG_EVERY == 0:
logging.info('On image %d', idx) logging.info('On image %d', idx)
total_num_annotations_skipped += num_annotations_skipped total_num_annotations_skipped += num_annotations_skipped
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment