"megatron/mpu/vscode:/vscode.git/clone" did not exist on "b886b7bb972afe72bac0f5de4f42a4a7bae8ebef"
Unverified Commit b6c0c7f9 authored by Toby Boyd's avatar Toby Boyd Committed by GitHub
Browse files

Remove contrib thread pool. (#6175)

* Remove contrib thread pool.

* Remove commented out contrib import.

* Fix lint issues.

* move tf.data.options higher. Tweak line breaks.
parent 27e86174
......@@ -31,7 +31,6 @@ import os
# pylint: disable=g-bad-import-order
from absl import flags
import tensorflow as tf
from tensorflow.contrib.data.python.ops import threadpool
from official.resnet import resnet_model
from official.utils.flags import core as flags_core
......@@ -75,6 +74,15 @@ def process_record_dataset(dataset,
Returns:
Dataset of (image, label) pairs ready for iteration.
"""
# Defines a specific size thread pool for tf.data operations.
if datasets_num_private_threads:
options = tf.data.Options()
options.experimental_threading = tf.data.experimental.ThreadingOptions()
options.experimental_threading.private_threadpool_size = (
datasets_num_private_threads)
dataset = dataset.with_options(options)
tf.compat.v1.logging.info('datasets_num_private_threads: %s',
datasets_num_private_threads)
# Prefetches a batch at a time to smooth out the time taken to load input
# files for shuffling and processing.
......@@ -102,16 +110,6 @@ def process_record_dataset(dataset,
# on how many devices are present.
dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
# Defines a specific size thread pool for tf.data operations.
if datasets_num_private_threads:
tf.compat.v1.logging.info('datasets_num_private_threads: %s',
datasets_num_private_threads)
dataset = threadpool.override_threadpool(
dataset,
threadpool.PrivateThreadPool(
datasets_num_private_threads,
display_name='input_pipeline_thread_pool'))
return dataset
......
......@@ -71,7 +71,8 @@ def construct_scalar_host_call(metric_dict, model_dir, prefix=""):
# expects [batch_size, ...] Tensors, thus reshape to introduce a batch
# dimension. These Tensors are implicitly concatenated to
# [params['batch_size']].
global_step_tensor = tf.reshape(tf.compat.v1.train.get_or_create_global_step(), [1])
global_step_tensor = tf.reshape(
tf.compat.v1.train.get_or_create_global_step(), [1])
other_tensors = [tf.reshape(metric_dict[key], [1]) for key in metric_names]
return host_call_fn, [global_step_tensor] + other_tensors
......
......@@ -156,6 +156,7 @@ def _serialize_shards(df_shards, columns, pool, writer):
for example in s:
writer.write(example)
def write_to_buffer(dataframe, buffer_path, columns, expected_size=None):
"""Write a dataframe to a binary file for a dataset to consume.
......@@ -169,7 +170,8 @@ def write_to_buffer(dataframe, buffer_path, columns, expected_size=None):
Returns:
The path of the buffer.
"""
if tf.io.gfile.exists(buffer_path) and tf.io.gfile.stat(buffer_path).length > 0:
if (tf.io.gfile.exists(buffer_path) and
tf.io.gfile.stat(buffer_path).length > 0):
actual_size = tf.io.gfile.stat(buffer_path).length
if expected_size == actual_size:
return buffer_path
......@@ -184,7 +186,8 @@ def write_to_buffer(dataframe, buffer_path, columns, expected_size=None):
tf.io.gfile.makedirs(os.path.split(buffer_path)[0])
tf.compat.v1.logging.info("Constructing TFRecordDataset buffer: {}".format(buffer_path))
tf.compat.v1.logging.info("Constructing TFRecordDataset buffer: {}"
.format(buffer_path))
count = 0
pool = multiprocessing.Pool(multiprocessing.cpu_count())
......
......@@ -57,8 +57,9 @@ def get_train_hooks(name_list, use_tpu=False, **kwargs):
return []
if use_tpu:
tf.compat.v1.logging.warning("hooks_helper received name_list `{}`, but a TPU is "
"specified. No hooks will be used.".format(name_list))
tf.compat.v1.logging.warning('hooks_helper received name_list `{}`, but a '
'TPU is specified. No hooks will be used.'
.format(name_list))
return []
train_hooks = []
......@@ -142,6 +143,7 @@ def get_logging_metric_hook(tensors_to_log=None,
names. If not set, log _TENSORS_TO_LOG by default.
every_n_secs: `int`, the frequency for logging the metric. Default to every
10 mins.
**kwargs: a dictionary of arguments.
Returns:
Returns a LoggingMetricHook that saves tensor values in a JSON format.
......
......@@ -88,6 +88,6 @@ def generate_synthetic_data(
def apply_clean(flags_obj):
if flags_obj.clean and tf.io.gfile.exists(flags_obj.model_dir):
tf.compat.v1.logging.info("--clean flag set. Removing existing model dir: {}".format(
flags_obj.model_dir))
tf.compat.v1.logging.info("--clean flag set. Removing existing model dir:"
" {}".format(flags_obj.model_dir))
tf.io.gfile.rmtree(flags_obj.model_dir)
......@@ -191,10 +191,11 @@ class BaseTest(tf.test.TestCase):
if correctness_function is not None:
results = correctness_function(*eval_results)
with tf.io.gfile.GFile(os.path.join(data_dir, "results.json"), "w") as f:
result_json = os.path.join(data_dir, "results.json")
with tf.io.gfile.GFile(result_json, "w") as f:
json.dump(results, f)
with tf.io.gfile.GFile(os.path.join(data_dir, "tf_version.json"), "w") as f:
tf_version_json = os.path.join(data_dir, "tf_version.json")
with tf.io.gfile.GFile(tf_version_json, "w") as f:
json.dump([tf.version.VERSION, tf.version.GIT_VERSION], f)
def _evaluate_test_case(self, name, graph, ops_to_eval, correctness_function):
......@@ -262,7 +263,8 @@ class BaseTest(tf.test.TestCase):
eval_results = [op.eval() for op in ops_to_eval]
if correctness_function is not None:
results = correctness_function(*eval_results)
with tf.io.gfile.GFile(os.path.join(data_dir, "results.json"), "r") as f:
result_json = os.path.join(data_dir, "results.json")
with tf.io.gfile.GFile(result_json, "r") as f:
expected_results = json.load(f)
self.assertAllClose(results, expected_results)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment