Unverified Commit cf304238 authored by Bruce Fontaine's avatar Bruce Fontaine Committed by GitHub
Browse files

Add support for TPUEstimator to data processing pipeline and add the … (#6330)

* Add support for TPUEstimator to data processing pipeline and add the ability to store epochs in user specified location.
parent dadc4a62
......@@ -262,7 +262,7 @@ class DatasetManager(object):
file_pattern = os.path.join(
epoch_data_dir, rconst.SHARD_TEMPLATE.format("*"))
dataset = StreamingFilesDataset(
files=file_pattern, worker_job="worker",
files=file_pattern, worker_job=popen_helper.worker_job(),
num_parallel_reads=rconst.NUM_FILE_SHARDS, num_epochs=1,
sloppy=not self._deterministic)
map_fn = functools.partial(self._deserialize, batch_size=batch_size)
......@@ -297,8 +297,12 @@ class DatasetManager(object):
"""Create an input_fn which checks for batch size consistency."""
def input_fn(params):
"""Returns batches for training."""
# Estimator passes batch_size during training and eval_batch_size during
# eval. TPUEstimator only passes batch_size.
param_batch_size = (params["batch_size"] if self._is_training else
params["eval_batch_size"])
params.get("eval_batch_size") or params["batch_size"])
if batch_size != param_batch_size:
raise ValueError("producer batch size ({}) differs from params batch "
"size ({})".format(batch_size, param_batch_size))
......@@ -338,7 +342,8 @@ class BaseDataConstructor(threading.Thread):
eval_batch_size, # type: int
batches_per_eval_step, # type: int
stream_files, # type: bool
deterministic=False # type: bool
deterministic=False, # type: bool
epoch_dir=None # type: string
):
# General constants
self._maximum_number_epochs = maximum_number_epochs
......@@ -382,7 +387,7 @@ class BaseDataConstructor(threading.Thread):
self._shuffle_with_forkpool = not stream_files
if stream_files:
self._shard_root = tempfile.mkdtemp(prefix="ncf_")
self._shard_root = epoch_dir or tempfile.mkdtemp(prefix="ncf_")
atexit.register(tf.gfile.DeleteRecursively, dirname=self._shard_root)
else:
self._shard_root = None
......@@ -648,9 +653,12 @@ class DummyConstructor(threading.Thread):
"""Construct training input_fn that uses synthetic data."""
def input_fn(params):
"""Generated input_fn for the given epoch."""
"""Returns dummy input batches for training."""
# Estimator passes batch_size during training and eval_batch_size during
# eval. TPUEstimator only passes batch_size.
batch_size = (params["batch_size"] if is_training else
params["eval_batch_size"])
params.get("eval_batch_size") or params["batch_size"])
num_users = params["num_users"]
num_items = params["num_items"]
......
......@@ -176,8 +176,8 @@ def _filter_index_sort(raw_rating_path, cache_path):
def instantiate_pipeline(dataset, data_dir, params, constructor_type=None,
deterministic=False):
# type: (str, str, dict, typing.Optional[str], bool) -> (NCFDataset, typing.Callable)
deterministic=False, epoch_dir=None):
# type: (str, str, dict, typing.Optional[str], bool, typing.Optional[str]) -> (NCFDataset, typing.Callable)
"""Load and digest data CSV into a usable form.
Args:
......@@ -187,6 +187,7 @@ def instantiate_pipeline(dataset, data_dir, params, constructor_type=None,
constructor_type: The name of the constructor subclass that should be used
for the input pipeline.
deterministic: Tell the data constructor to produce deterministically.
epoch_dir: Directory in which to store the training epochs.
"""
tf.logging.info("Beginning data preprocessing.")
......@@ -221,7 +222,8 @@ def instantiate_pipeline(dataset, data_dir, params, constructor_type=None,
eval_batch_size=params["eval_batch_size"],
batches_per_eval_step=params["batches_per_step"],
stream_files=params["use_tpu"],
deterministic=deterministic
deterministic=deterministic,
epoch_dir=epoch_dir
)
run_time = timeit.default_timer() - st
......
......@@ -58,3 +58,7 @@ class FauxPool(object):
def get_fauxpool(num_workers, init_worker=None, closing=True):
pool = FauxPool(processes=num_workers, initializer=init_worker)
return contextlib.closing(pool) if closing else pool
def worker_job():
return "worker"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment