Unverified Commit b6c0c7f9 authored by Toby Boyd's avatar Toby Boyd Committed by GitHub
Browse files

Remove contrib thread pool. (#6175)

* Remove contrib thread pool.

* Remove commented out contrib import.

* Fix lint issues.

* move tf.data.options higher. Tweak line breaks.
parent 27e86174
...@@ -31,7 +31,6 @@ import os ...@@ -31,7 +31,6 @@ import os
# pylint: disable=g-bad-import-order # pylint: disable=g-bad-import-order
from absl import flags from absl import flags
import tensorflow as tf import tensorflow as tf
from tensorflow.contrib.data.python.ops import threadpool
from official.resnet import resnet_model from official.resnet import resnet_model
from official.utils.flags import core as flags_core from official.utils.flags import core as flags_core
...@@ -75,6 +74,15 @@ def process_record_dataset(dataset, ...@@ -75,6 +74,15 @@ def process_record_dataset(dataset,
Returns: Returns:
Dataset of (image, label) pairs ready for iteration. Dataset of (image, label) pairs ready for iteration.
""" """
# Defines a specific size thread pool for tf.data operations.
if datasets_num_private_threads:
options = tf.data.Options()
options.experimental_threading = tf.data.experimental.ThreadingOptions()
options.experimental_threading.private_threadpool_size = (
datasets_num_private_threads)
dataset = dataset.with_options(options)
tf.compat.v1.logging.info('datasets_num_private_threads: %s',
datasets_num_private_threads)
# Prefetches a batch at a time to smooth out the time taken to load input # Prefetches a batch at a time to smooth out the time taken to load input
# files for shuffling and processing. # files for shuffling and processing.
...@@ -102,16 +110,6 @@ def process_record_dataset(dataset, ...@@ -102,16 +110,6 @@ def process_record_dataset(dataset,
# on how many devices are present. # on how many devices are present.
dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE) dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
# Defines a specific size thread pool for tf.data operations.
if datasets_num_private_threads:
tf.compat.v1.logging.info('datasets_num_private_threads: %s',
datasets_num_private_threads)
dataset = threadpool.override_threadpool(
dataset,
threadpool.PrivateThreadPool(
datasets_num_private_threads,
display_name='input_pipeline_thread_pool'))
return dataset return dataset
......
...@@ -71,7 +71,8 @@ def construct_scalar_host_call(metric_dict, model_dir, prefix=""): ...@@ -71,7 +71,8 @@ def construct_scalar_host_call(metric_dict, model_dir, prefix=""):
# expects [batch_size, ...] Tensors, thus reshape to introduce a batch # expects [batch_size, ...] Tensors, thus reshape to introduce a batch
# dimension. These Tensors are implicitly concatenated to # dimension. These Tensors are implicitly concatenated to
# [params['batch_size']]. # [params['batch_size']].
global_step_tensor = tf.reshape(tf.compat.v1.train.get_or_create_global_step(), [1]) global_step_tensor = tf.reshape(
tf.compat.v1.train.get_or_create_global_step(), [1])
other_tensors = [tf.reshape(metric_dict[key], [1]) for key in metric_names] other_tensors = [tf.reshape(metric_dict[key], [1]) for key in metric_names]
return host_call_fn, [global_step_tensor] + other_tensors return host_call_fn, [global_step_tensor] + other_tensors
......
...@@ -156,6 +156,7 @@ def _serialize_shards(df_shards, columns, pool, writer): ...@@ -156,6 +156,7 @@ def _serialize_shards(df_shards, columns, pool, writer):
for example in s: for example in s:
writer.write(example) writer.write(example)
def write_to_buffer(dataframe, buffer_path, columns, expected_size=None): def write_to_buffer(dataframe, buffer_path, columns, expected_size=None):
"""Write a dataframe to a binary file for a dataset to consume. """Write a dataframe to a binary file for a dataset to consume.
...@@ -169,7 +170,8 @@ def write_to_buffer(dataframe, buffer_path, columns, expected_size=None): ...@@ -169,7 +170,8 @@ def write_to_buffer(dataframe, buffer_path, columns, expected_size=None):
Returns: Returns:
The path of the buffer. The path of the buffer.
""" """
if tf.io.gfile.exists(buffer_path) and tf.io.gfile.stat(buffer_path).length > 0: if (tf.io.gfile.exists(buffer_path) and
tf.io.gfile.stat(buffer_path).length > 0):
actual_size = tf.io.gfile.stat(buffer_path).length actual_size = tf.io.gfile.stat(buffer_path).length
if expected_size == actual_size: if expected_size == actual_size:
return buffer_path return buffer_path
...@@ -184,7 +186,8 @@ def write_to_buffer(dataframe, buffer_path, columns, expected_size=None): ...@@ -184,7 +186,8 @@ def write_to_buffer(dataframe, buffer_path, columns, expected_size=None):
tf.io.gfile.makedirs(os.path.split(buffer_path)[0]) tf.io.gfile.makedirs(os.path.split(buffer_path)[0])
tf.compat.v1.logging.info("Constructing TFRecordDataset buffer: {}".format(buffer_path)) tf.compat.v1.logging.info("Constructing TFRecordDataset buffer: {}"
.format(buffer_path))
count = 0 count = 0
pool = multiprocessing.Pool(multiprocessing.cpu_count()) pool = multiprocessing.Pool(multiprocessing.cpu_count())
...@@ -195,7 +198,7 @@ def write_to_buffer(dataframe, buffer_path, columns, expected_size=None): ...@@ -195,7 +198,7 @@ def write_to_buffer(dataframe, buffer_path, columns, expected_size=None):
_serialize_shards(df_shards, columns, pool, writer) _serialize_shards(df_shards, columns, pool, writer)
count += sum([len(s) for s in df_shards]) count += sum([len(s) for s in df_shards])
tf.compat.v1.logging.info("{}/{} examples written." tf.compat.v1.logging.info("{}/{} examples written."
.format(str(count).ljust(8), len(dataframe))) .format(str(count).ljust(8), len(dataframe)))
finally: finally:
pool.terminate() pool.terminate()
......
...@@ -57,8 +57,9 @@ def get_train_hooks(name_list, use_tpu=False, **kwargs): ...@@ -57,8 +57,9 @@ def get_train_hooks(name_list, use_tpu=False, **kwargs):
return [] return []
if use_tpu: if use_tpu:
tf.compat.v1.logging.warning("hooks_helper received name_list `{}`, but a TPU is " tf.compat.v1.logging.warning('hooks_helper received name_list `{}`, but a '
"specified. No hooks will be used.".format(name_list)) 'TPU is specified. No hooks will be used.'
.format(name_list))
return [] return []
train_hooks = [] train_hooks = []
...@@ -142,6 +143,7 @@ def get_logging_metric_hook(tensors_to_log=None, ...@@ -142,6 +143,7 @@ def get_logging_metric_hook(tensors_to_log=None,
names. If not set, log _TENSORS_TO_LOG by default. names. If not set, log _TENSORS_TO_LOG by default.
every_n_secs: `int`, the frequency for logging the metric. Default to every every_n_secs: `int`, the frequency for logging the metric. Default to every
10 mins. 10 mins.
**kwargs: a dictionary of arguments.
Returns: Returns:
Returns a LoggingMetricHook that saves tensor values in a JSON format. Returns a LoggingMetricHook that saves tensor values in a JSON format.
......
...@@ -88,6 +88,6 @@ def generate_synthetic_data( ...@@ -88,6 +88,6 @@ def generate_synthetic_data(
def apply_clean(flags_obj): def apply_clean(flags_obj):
if flags_obj.clean and tf.io.gfile.exists(flags_obj.model_dir): if flags_obj.clean and tf.io.gfile.exists(flags_obj.model_dir):
tf.compat.v1.logging.info("--clean flag set. Removing existing model dir: {}".format( tf.compat.v1.logging.info("--clean flag set. Removing existing model dir:"
flags_obj.model_dir)) " {}".format(flags_obj.model_dir))
tf.io.gfile.rmtree(flags_obj.model_dir) tf.io.gfile.rmtree(flags_obj.model_dir)
...@@ -191,10 +191,11 @@ class BaseTest(tf.test.TestCase): ...@@ -191,10 +191,11 @@ class BaseTest(tf.test.TestCase):
if correctness_function is not None: if correctness_function is not None:
results = correctness_function(*eval_results) results = correctness_function(*eval_results)
with tf.io.gfile.GFile(os.path.join(data_dir, "results.json"), "w") as f: result_json = os.path.join(data_dir, "results.json")
with tf.io.gfile.GFile(result_json, "w") as f:
json.dump(results, f) json.dump(results, f)
tf_version_json = os.path.join(data_dir, "tf_version.json")
with tf.io.gfile.GFile(os.path.join(data_dir, "tf_version.json"), "w") as f: with tf.io.gfile.GFile(tf_version_json, "w") as f:
json.dump([tf.version.VERSION, tf.version.GIT_VERSION], f) json.dump([tf.version.VERSION, tf.version.GIT_VERSION], f)
def _evaluate_test_case(self, name, graph, ops_to_eval, correctness_function): def _evaluate_test_case(self, name, graph, ops_to_eval, correctness_function):
...@@ -262,7 +263,8 @@ class BaseTest(tf.test.TestCase): ...@@ -262,7 +263,8 @@ class BaseTest(tf.test.TestCase):
eval_results = [op.eval() for op in ops_to_eval] eval_results = [op.eval() for op in ops_to_eval]
if correctness_function is not None: if correctness_function is not None:
results = correctness_function(*eval_results) results = correctness_function(*eval_results)
with tf.io.gfile.GFile(os.path.join(data_dir, "results.json"), "r") as f: result_json = os.path.join(data_dir, "results.json")
with tf.io.gfile.GFile(result_json, "r") as f:
expected_results = json.load(f) expected_results = json.load(f)
self.assertAllClose(results, expected_results) self.assertAllClose(results, expected_results)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment