Commit 9d42f797 authored by Taylor Robie's avatar Taylor Robie
Browse files

remove 'deterministic'

parent c5ff4ec7
...@@ -197,15 +197,14 @@ def _filter_index_sort(raw_rating_path, cache_path, match_mlperf): ...@@ -197,15 +197,14 @@ def _filter_index_sort(raw_rating_path, cache_path, match_mlperf):
return data, valid_cache return data, valid_cache
def instantiate_pipeline(dataset, data_dir, deterministic, params): def instantiate_pipeline(dataset, data_dir, params):
# type: (str, str, bool, dict) -> (NCFDataset, typing.Callable) # type: (str, str, dict) -> (NCFDataset, typing.Callable)
"""Load and digest data CSV into a usable form. """Load and digest data CSV into a usable form.
Args: Args:
dataset: The name of the dataset to be used. dataset: The name of the dataset to be used.
data_dir: The root directory of the dataset. data_dir: The root directory of the dataset.
deterministic: Try to enforce repeatable behavior, even at the cost of params: dict of parameters for the run.
performance.
""" """
tf.logging.info("Beginning data preprocessing.") tf.logging.info("Beginning data preprocessing.")
...@@ -225,9 +224,6 @@ def instantiate_pipeline(dataset, data_dir, deterministic, params): ...@@ -225,9 +224,6 @@ def instantiate_pipeline(dataset, data_dir, deterministic, params):
raise ValueError("Expected to find {} items, but found {}".format( raise ValueError("Expected to find {} items, but found {}".format(
num_items, len(item_map))) num_items, len(item_map)))
if deterministic:
raise NotImplementedError("Fixed seed behavior has not been implemented.")
producer = data_pipeline.MaterializedDataConstructor( producer = data_pipeline.MaterializedDataConstructor(
maximum_number_epochs=params["train_epochs"], maximum_number_epochs=params["train_epochs"],
num_users=num_users, num_users=num_users,
......
...@@ -195,6 +195,8 @@ def run_ncf(_): ...@@ -195,6 +195,8 @@ def run_ncf(_):
if FLAGS.seed is not None: if FLAGS.seed is not None:
np.random.seed(FLAGS.seed) np.random.seed(FLAGS.seed)
tf.logging.warning("Values may still vary from run to run due to thread "
"execution ordering.")
params = parse_flags(FLAGS) params = parse_flags(FLAGS)
total_training_cycle = FLAGS.train_epochs // FLAGS.epochs_between_evals total_training_cycle = FLAGS.train_epochs // FLAGS.epochs_between_evals
...@@ -207,8 +209,7 @@ def run_ncf(_): ...@@ -207,8 +209,7 @@ def run_ncf(_):
num_eval_steps = rconst.SYNTHETIC_BATCHES_PER_EPOCH num_eval_steps = rconst.SYNTHETIC_BATCHES_PER_EPOCH
else: else:
num_users, num_items, producer = data_preprocessing.instantiate_pipeline( num_users, num_items, producer = data_preprocessing.instantiate_pipeline(
dataset=FLAGS.dataset, data_dir=FLAGS.data_dir, dataset=FLAGS.dataset, data_dir=FLAGS.data_dir, params=params)
deterministic=FLAGS.seed is not None, params=params)
num_train_steps = (producer.train_batches_per_epoch // num_train_steps = (producer.train_batches_per_epoch //
params["batches_per_step"]) params["batches_per_step"])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment