Unverified Commit cf304238 authored by Bruce Fontaine, committed by GitHub

Add support for TPUEstimator to data processing pipeline and add the ability to store epochs in a user-specified location. (#6330)
parent dadc4a62
@@ -262,7 +262,7 @@ class DatasetManager(object):
       file_pattern = os.path.join(
           epoch_data_dir, rconst.SHARD_TEMPLATE.format("*"))
       dataset = StreamingFilesDataset(
-          files=file_pattern, worker_job="worker",
+          files=file_pattern, worker_job=popen_helper.worker_job(),
           num_parallel_reads=rconst.NUM_FILE_SHARDS, num_epochs=1,
           sloppy=not self._deterministic)
       map_fn = functools.partial(self._deserialize, batch_size=batch_size)
@@ -297,8 +297,12 @@ class DatasetManager(object):
     """Create an input_fn which checks for batch size consistency."""

     def input_fn(params):
+      """Returns batches for training."""
+      # Estimator passes batch_size during training and eval_batch_size during
+      # eval. TPUEstimator only passes batch_size.
       param_batch_size = (params["batch_size"] if self._is_training else
-                          params["eval_batch_size"])
+                          params.get("eval_batch_size") or
+                          params["batch_size"])
       if batch_size != param_batch_size:
         raise ValueError("producer batch size ({}) differs from params batch "
                          "size ({})".format(batch_size, param_batch_size))
@@ -338,7 +342,8 @@ class BaseDataConstructor(threading.Thread):
                eval_batch_size,        # type: int
                batches_per_eval_step,  # type: int
                stream_files,           # type: bool
-               deterministic=False     # type: bool
+               deterministic=False,    # type: bool
+               epoch_dir=None          # type: str
               ):
     # General constants
     self._maximum_number_epochs = maximum_number_epochs
@@ -382,7 +387,7 @@ class BaseDataConstructor(threading.Thread):
     self._shuffle_with_forkpool = not stream_files

     if stream_files:
-      self._shard_root = tempfile.mkdtemp(prefix="ncf_")
+      self._shard_root = epoch_dir or tempfile.mkdtemp(prefix="ncf_")
       atexit.register(tf.gfile.DeleteRecursively, dirname=self._shard_root)
     else:
       self._shard_root = None
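For reference, the shard-root selection above reduces to a small idiom: prefer the caller's directory, otherwise fall back to a temporary one, and register cleanup either way. A minimal standalone sketch, assuming TF 1.x's tf.gfile as in the surrounding code (the helper name is invented for illustration):

import atexit
import tempfile

import tensorflow as tf  # TF 1.x, where tf.gfile.DeleteRecursively exists


def choose_shard_root(epoch_dir=None):
  # Prefer the user-specified epoch directory; otherwise create a fresh
  # temporary directory, mirroring BaseDataConstructor above.
  shard_root = epoch_dir or tempfile.mkdtemp(prefix="ncf_")
  # Note: as in the diff, cleanup is registered unconditionally, so a
  # user-supplied directory is also deleted at interpreter exit.
  atexit.register(tf.gfile.DeleteRecursively, dirname=shard_root)
  return shard_root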
@@ -648,9 +653,12 @@ class DummyConstructor(threading.Thread):
     """Construct training input_fn that uses synthetic data."""

     def input_fn(params):
-      """Generated input_fn for the given epoch."""
+      """Returns dummy input batches for training."""
+      # Estimator passes batch_size during training and eval_batch_size during
+      # eval. TPUEstimator only passes batch_size.
       batch_size = (params["batch_size"] if is_training else
-                    params["eval_batch_size"])
+                    params.get("eval_batch_size") or
+                    params["batch_size"])
       num_users = params["num_users"]
       num_items = params["num_items"]
...
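Both input_fns above rely on the same fallback: Estimator supplies batch_size for training and eval_batch_size for eval, while TPUEstimator supplies only batch_size. A minimal sketch of the lookup in isolation (the helper name resolve_batch_size is hypothetical, not part of this commit):

def resolve_batch_size(params, is_training):
  """Returns the effective batch size from an Estimator params dict."""
  if is_training:
    return params["batch_size"]
  # Estimator provides "eval_batch_size" during eval; TPUEstimator does not,
  # so fall back to "batch_size" when the key is missing or None.
  return params.get("eval_batch_size") or params["batch_size"]


# Estimator-style eval params use the dedicated eval batch size ...
assert resolve_batch_size(
    {"batch_size": 256, "eval_batch_size": 1024}, is_training=False) == 1024
# ... while TPUEstimator-style params fall back to batch_size.
assert resolve_batch_size({"batch_size": 256}, is_training=False) == 256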
@@ -176,8 +176,8 @@ def _filter_index_sort(raw_rating_path, cache_path):

 def instantiate_pipeline(dataset, data_dir, params, constructor_type=None,
-                         deterministic=False):
-  # type: (str, str, dict, typing.Optional[str], bool) -> (NCFDataset, typing.Callable)
+                         deterministic=False, epoch_dir=None):
+  # type: (str, str, dict, typing.Optional[str], bool, typing.Optional[str]) -> (NCFDataset, typing.Callable)
   """Load and digest data CSV into a usable form.

   Args:
@@ -187,6 +187,7 @@ def instantiate_pipeline(dataset, data_dir, params, constructor_type=None,
     constructor_type: The name of the constructor subclass that should be used
       for the input pipeline.
     deterministic: Tell the data constructor to produce deterministically.
+    epoch_dir: Directory in which to store the training epochs.
   """
   tf.logging.info("Beginning data preprocessing.")
@@ -221,7 +222,8 @@ def instantiate_pipeline(dataset, data_dir, params, constructor_type=None,
         eval_batch_size=params["eval_batch_size"],
         batches_per_eval_step=params["batches_per_step"],
         stream_files=params["use_tpu"],
-        deterministic=deterministic
+        deterministic=deterministic,
+        epoch_dir=epoch_dir
     )
     run_time = timeit.default_timer() - st
...
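A hypothetical call site for the extended signature; every argument value below is a placeholder, the module path is assumed from the repo layout, and the return pair is read off the type comment above:

from official.recommendation import data_preprocessing  # assumed module path

params = {
    "batch_size": 256,
    "eval_batch_size": 1024,
    "batches_per_step": 1,
    "use_tpu": False,
    # ... plus whatever other keys the pipeline expects
}

ncf_dataset, producer_fn = data_preprocessing.instantiate_pipeline(
    dataset="ml-20m",             # placeholder dataset name
    data_dir="/tmp/movielens",    # placeholder input location
    params=params,
    constructor_type=None,        # fall back to the default constructor
    deterministic=False,
    epoch_dir="/tmp/ncf_epochs",  # epochs land here instead of a mkdtemp dir
)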
@@ -58,3 +58,7 @@ class FauxPool(object):

 def get_fauxpool(num_workers, init_worker=None, closing=True):
   pool = FauxPool(processes=num_workers, initializer=init_worker)
   return contextlib.closing(pool) if closing else pool
+
+
+def worker_job():
+  return "worker"
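The new helper replaces the hard-coded "worker" literal with a single definition, so the TPU worker job name can change in one place. Consumer side, with the import path assumed from the repo layout:

from official.recommendation import popen_helper

# DatasetManager now asks the helper for the TPU worker job name instead of
# embedding the "worker" string at each call site.
assert popen_helper.worker_job() == "worker"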