"tests/git@developer.sourcefind.cn:OpenDAS/d2go.git" did not exist on "1c9e0e83b5b268cc7a90c5ac9d6142adb7d304da"
Commit 4fb325da authored by Taylor Robie

Add bisection based producer for increased scalability, enable fully deterministic data production, and use the materialized and bisection producers to check each other (via expected output md5s).
parent 1048ffd5
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import atexit
+import collections
import functools
import os
import sys
@@ -79,8 +80,8 @@ class DatasetManager(object):
  management, tf.Dataset creation, etc.).
  """
  def __init__(self, is_training, stream_files, batches_per_epoch,
-               shard_root=None):
-    # type: (bool, bool, int, typing.Optional[str]) -> None
+               shard_root=None, deterministic=False):
+    # type: (bool, bool, int, typing.Optional[str], bool) -> None
    """Constructs a `DatasetManager` instance.
    Args:
      is_training: Boolean of whether the data provided is training or
@@ -91,8 +92,10 @@ class DatasetManager(object):
        written to file shards.
      batches_per_epoch: The number of batches in a single epoch.
      shard_root: The base directory to be used when stream_files=True.
+      deterministic: Forgo non-deterministic speedups (i.e. sloppy=True).
    """
    self._is_training = is_training
+    self._deterministic = deterministic
    self._stream_files = stream_files
    self._writers = []
    self._write_locks = [threading.RLock() for _ in
@@ -259,7 +262,8 @@ class DatasetManager(object):
          epoch_data_dir, rconst.SHARD_TEMPLATE.format("*"))
      dataset = StreamingFilesDataset(
          files=file_pattern, worker_job="worker",
-          num_parallel_reads=rconst.NUM_FILE_SHARDS, num_epochs=1)
+          num_parallel_reads=rconst.NUM_FILE_SHARDS, num_epochs=1,
+          sloppy=not self._deterministic)
      map_fn = functools.partial(self._deserialize, batch_size=batch_size)
      dataset = dataset.map(map_fn, num_parallel_calls=16)
@@ -332,7 +336,8 @@ class BaseDataConstructor(threading.Thread):
               eval_pos_items,         # type: np.ndarray
               eval_batch_size,        # type: int
               batches_per_eval_step,  # type: int
-               stream_files            # type: bool
+               stream_files,           # type: bool
+               deterministic=False     # type: bool
              ):
    # General constants
    self._maximum_number_epochs = maximum_number_epochs
@@ -381,15 +386,18 @@ class BaseDataConstructor(threading.Thread):
      self._shard_root = None

    self._train_dataset = DatasetManager(
-        True, stream_files, self.train_batches_per_epoch, self._shard_root)
+        True, stream_files, self.train_batches_per_epoch, self._shard_root,
+        deterministic)
    self._eval_dataset = DatasetManager(
-        False, stream_files, self.eval_batches_per_epoch, self._shard_root)
+        False, stream_files, self.eval_batches_per_epoch, self._shard_root,
+        deterministic)

    # Threading details
    super(BaseDataConstructor, self).__init__()
    self.daemon = True
    self._stop_loop = False
    self._fatal_exception = None
+    self.deterministic = deterministic

  def __str__(self):
    multiplier = ("(x{} devices)".format(self._batches_per_train_step)
@@ -428,6 +436,7 @@ class BaseDataConstructor(threading.Thread):
    self._construct_eval_epoch()
    for _ in range(self._maximum_number_epochs - 1):
      self._construct_training_epoch()
+    self.stop_loop()

  def run(self):
    try:
@@ -445,7 +454,8 @@ class BaseDataConstructor(threading.Thread):
    atexit.register(pool.close)
    args = [(self._elements_in_epoch, stat_utils.random_int32())
            for _ in range(self._maximum_number_epochs)]
-    self._shuffle_iterator = pool.imap_unordered(stat_utils.permutation, args)
+    imap = pool.imap if self.deterministic else pool.imap_unordered
+    self._shuffle_iterator = imap(stat_utils.permutation, args)

  def _get_training_batch(self, i):
    """Construct a single batch of training data.
@@ -511,7 +521,9 @@ class BaseDataConstructor(threading.Thread):
    map_args = list(range(self.train_batches_per_epoch))
    self._current_epoch_order = next(self._shuffle_iterator)
-    with popen_helper.get_threadpool(6) as pool:
+    get_pool = (popen_helper.get_fauxpool if self.deterministic else
+                popen_helper.get_threadpool)
+    with get_pool(6) as pool:
      pool.map(self._get_training_batch, map_args)
    self._train_dataset.end_construction()
@@ -590,7 +602,10 @@ class BaseDataConstructor(threading.Thread):
    self._eval_dataset.start_construction()
    map_args = [i for i in range(self.eval_batches_per_epoch)]
-    with popen_helper.get_threadpool(6) as pool:
+    get_pool = (popen_helper.get_fauxpool if self.deterministic else
+                popen_helper.get_threadpool)
+    with get_pool(6) as pool:
      pool.map(self._get_eval_batch, map_args)
    self._eval_dataset.end_construction()
@@ -733,3 +748,119 @@ class MaterializedDataConstructor(BaseDataConstructor):
    negative_item_choice = stat_utils.very_slightly_biased_randint(
        self._per_user_neg_count[negative_users])
    return self._negative_table[negative_users, negative_item_choice]
+
+
+class BisectionDataConstructor(BaseDataConstructor):
+  """Use bisection to index within positive examples.
+
+  This class tallies the number of negative items which appear before each
+  positive item for a user. This means that in order to select the ith negative
+  item for a user, it only needs to determine which two positive items bound
+  it, at which point the item id for the ith negative is a simple algebraic
+  expression.
+  """
+
+  def __init__(self, *args, **kwargs):
+    super(BisectionDataConstructor, self).__init__(*args, **kwargs)
+    self.index_bounds = None
+    self._sorted_train_pos_items = None
+    self._total_negatives = None
+
+  def _index_segment(self, user):
+    lower, upper = self.index_bounds[user:user+2]
+    items = self._sorted_train_pos_items[lower:upper]
+
+    negatives_since_last_positive = np.concatenate(
+        [items[0][np.newaxis], items[1:] - items[:-1] - 1])
+
+    return np.cumsum(negatives_since_last_positive)
+
+  def construct_lookup_variables(self):
+    start_time = timeit.default_timer()
+    inner_bounds = np.argwhere(self._train_pos_users[1:] -
+                               self._train_pos_users[:-1])[:, 0] + 1
+    (upper_bound,) = self._train_pos_users.shape
+    self.index_bounds = np.array([0] + inner_bounds.tolist() + [upper_bound])
+
+    # Later logic will assume that the users are in sequential ascending order.
+    assert np.array_equal(self._train_pos_users[self.index_bounds[:-1]],
+                          np.arange(self._num_users))
+
+    self._sorted_train_pos_items = self._train_pos_items.copy()
+
+    for i in range(self._num_users):
+      lower, upper = self.index_bounds[i:i+2]
+      self._sorted_train_pos_items[lower:upper].sort()
+
+    self._total_negatives = np.concatenate([
+        self._index_segment(i) for i in range(self._num_users)])
+
+    tf.logging.info("Negative total vector built. Time: {:.1f} seconds".format(
+        timeit.default_timer() - start_time))
+
+  def lookup_negative_items(self, negative_users, **kwargs):
+    output = np.zeros(shape=negative_users.shape, dtype=rconst.ITEM_DTYPE) - 1
+
+    left_index = self.index_bounds[negative_users]
+    right_index = self.index_bounds[negative_users + 1] - 1
+
+    num_positives = right_index - left_index + 1
+    num_negatives = self._num_items - num_positives
+    neg_item_choice = stat_utils.very_slightly_biased_randint(num_negatives)
+
+    # Shortcuts:
+    # For points where the negative is greater than or equal to the tally
+    # before the last positive point there is no need to bisect. Instead the
+    # item id corresponding to the negative item choice is simply:
+    #   last_positive_index + 1 + (neg_choice - last_negative_tally)
+    # Similarly, if the selection is less than the tally at the first positive
+    # then the item_id is simply the selection.
+    #
+    # Because MovieLens organizes popular movies into low integers (which is
+    # preserved through the preprocessing), the first shortcut is very
+    # efficient, allowing ~60% of samples to bypass the bisection. For the same
+    # reason, the second shortcut is rarely triggered (<0.02%) and is therefore
+    # not worth implementing.
+    use_shortcut = neg_item_choice >= self._total_negatives[right_index]
+    output[use_shortcut] = (
+        self._sorted_train_pos_items[right_index] + 1 +
+        (neg_item_choice - self._total_negatives[right_index])
+    )[use_shortcut]
+
+    not_use_shortcut = np.logical_not(use_shortcut)
+    left_index = left_index[not_use_shortcut]
+    right_index = right_index[not_use_shortcut]
+    neg_item_choice = neg_item_choice[not_use_shortcut]
+
+    num_loops = np.max(
+        np.ceil(np.log2(num_positives[not_use_shortcut])).astype(np.int32))
+
+    for i in range(num_loops):
+      mid_index = (left_index + right_index) // 2
+      right_criteria = self._total_negatives[mid_index] > neg_item_choice
+      left_criteria = np.logical_not(right_criteria)
+
+      right_index[right_criteria] = mid_index[right_criteria]
+      left_index[left_criteria] = mid_index[left_criteria]
+
+    # Expected state after bisection pass:
+    #   The right index is the smallest index whose tally is greater than the
+    #   negative item choice index.
+    assert np.all((right_index - left_index) <= 1)
+
+    output[not_use_shortcut] = (
+        self._sorted_train_pos_items[right_index] -
+        (self._total_negatives[right_index] - neg_item_choice)
+    )
+
+    assert np.all(output >= 0)
+
+    return output
+
+
+def get_constructor(name):
+  if name == "bisection":
+    return BisectionDataConstructor
+  if name == "materialized":
+    return MaterializedDataConstructor
+  raise ValueError("Unrecognized constructor: {}".format(name))
@@ -197,14 +197,18 @@ def _filter_index_sort(raw_rating_path, cache_path, match_mlperf):
  return data, valid_cache


-def instantiate_pipeline(dataset, data_dir, params):
-  # type: (str, str, dict) -> (NCFDataset, typing.Callable)
+def instantiate_pipeline(dataset, data_dir, params, constructor_type=None,
+                         deterministic=False):
+  # type: (str, str, dict, typing.Optional[str], bool) -> (NCFDataset, typing.Callable)
  """Load and digest data CSV into a usable form.
  Args:
    dataset: The name of the dataset to be used.
    data_dir: The root directory of the dataset.
    params: dict of parameters for the run.
+    constructor_type: The name of the constructor subclass that should be used
+      for the input pipeline.
+    deterministic: Tell the data constructor to produce deterministically.
  """
  tf.logging.info("Beginning data preprocessing.")
@@ -224,7 +228,7 @@ def instantiate_pipeline(dataset, data_dir, params):
    raise ValueError("Expected to find {} items, but found {}".format(
        num_items, len(item_map)))

-  producer = data_pipeline.MaterializedDataConstructor(
+  producer = data_pipeline.get_constructor(constructor_type or "materialized")(
      maximum_number_epochs=params["train_epochs"],
      num_users=num_users,
      num_items=num_items,
@@ -239,7 +243,8 @@ def instantiate_pipeline(dataset, data_dir, params):
      eval_pos_items=raw_data[rconst.EVAL_ITEM_KEY],
      eval_batch_size=params["eval_batch_size"],
      batches_per_eval_step=params["batches_per_step"],
-      stream_files=params["use_tpu"]
+      stream_files=params["use_tpu"],
+      deterministic=deterministic
  )
  run_time = timeit.default_timer() - st
...
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from collections import defaultdict
+import hashlib
import os
import pickle
import time
@@ -31,6 +32,7 @@ import tensorflow as tf
from official.datasets import movielens
from official.recommendation import constants as rconst
from official.recommendation import data_preprocessing
+from official.recommendation import popen_helper
from official.recommendation import stat_utils
@@ -43,12 +45,23 @@ EVAL_BATCH_SIZE = 4000
NUM_NEG = 4

+END_TO_END_TRAIN_MD5 = "b218738e915e825d03939c5e305a2698"
+END_TO_END_EVAL_MD5 = "d753d0f3186831466d6e218163a9501e"
+FRESH_RANDOMNESS_MD5 = "63d0dff73c0e5f1048fbdc8c65021e22"
+

def mock_download(*args, **kwargs):
  return


class BaseTest(tf.test.TestCase):
  def setUp(self):
+    # The forkpool used by data producers interacts badly with the threading
+    # used by TestCase. Without this monkey patch tests will hang, and no
+    # amount of diligent closing and joining within the producer will
+    # prevent it.
+    self._get_forkpool = popen_helper.get_forkpool
+    popen_helper.get_forkpool = popen_helper.get_fauxpool
+
    self.temp_data_dir = self.get_temp_dir()
    ratings_folder = os.path.join(self.temp_data_dir, DATASET)
    tf.gfile.MakeDirs(ratings_folder)
@@ -86,6 +99,9 @@ class BaseTest(tf.test.TestCase):
    data_preprocessing.DATASET_TO_NUM_USERS_AND_ITEMS[DATASET] = (NUM_USERS,
                                                                   NUM_ITEMS)

+  def tearDown(self):
+    popen_helper.get_forkpool = self._get_forkpool
+
  def make_params(self, train_epochs=1):
    return {
        "train_epochs": train_epochs,
@@ -126,10 +142,11 @@ class BaseTest(tf.test.TestCase):
        break
    return output

-  def test_end_to_end(self):
+  def _test_end_to_end(self, constructor_type):
    params = self.make_params(train_epochs=1)
    _, _, producer = data_preprocessing.instantiate_pipeline(
-        dataset=DATASET, data_dir=self.temp_data_dir, params=params)
+        dataset=DATASET, data_dir=self.temp_data_dir, params=params,
+        constructor_type=constructor_type, deterministic=True)
    producer.start()
    producer.join()
@@ -154,10 +171,13 @@ class BaseTest(tf.test.TestCase):
        False: set(),
    }

+    md5 = hashlib.md5()
    for features, labels in first_epoch:
-      for u, i, v, l in zip(
-          features[movielens.USER_COLUMN], features[movielens.ITEM_COLUMN],
-          features[rconst.VALID_POINT_MASK], labels):
+      data_list = [
+          features[movielens.USER_COLUMN], features[movielens.ITEM_COLUMN],
+          features[rconst.VALID_POINT_MASK], labels]
+      [md5.update(i.tobytes()) for i in data_list]
+      for u, i, v, l in zip(*data_list):
        if not v:
          continue  # ignore padding
@@ -172,8 +192,9 @@ class BaseTest(tf.test.TestCase):
        train_examples[l].add((u_raw, i_raw))
        counts[(u_raw, i_raw)] += 1

+    self.assertRegexpMatches(md5.hexdigest(), END_TO_END_TRAIN_MD5)
+
    num_positives_seen = len(train_examples[True])
    self.assertEqual(producer._train_pos_users.shape[0], num_positives_seen)

    # This check is more heuristic because negatives are sampled with
@@ -196,10 +217,13 @@ class BaseTest(tf.test.TestCase):
      eval_data = self.drain_dataset(dataset=dataset, g=g)

    current_user = None
+    md5 = hashlib.md5()
    for features in eval_data:
-      for idx, (u, i, d) in enumerate(zip(features[movielens.USER_COLUMN],
-                                          features[movielens.ITEM_COLUMN],
-                                          features[rconst.DUPLICATE_MASK])):
+      data_list = [
+          features[movielens.USER_COLUMN], features[movielens.ITEM_COLUMN],
+          features[rconst.DUPLICATE_MASK]]
+      [md5.update(i.tobytes()) for i in data_list]
+      for idx, (u, i, d) in enumerate(zip(*data_list)):
        u_raw = user_inv_map[u]
        i_raw = item_inv_map[i]
        if current_user is None:
@@ -228,11 +252,14 @@ class BaseTest(tf.test.TestCase):
          # from the negatives.
          assert (u_raw, i_raw) not in self.seen_pairs

+    self.assertRegexpMatches(md5.hexdigest(), END_TO_END_EVAL_MD5)
+
-  def test_fresh_randomness(self):
+  def _test_fresh_randomness(self, constructor_type):
    train_epochs = 5
    params = self.make_params(train_epochs=train_epochs)
    _, _, producer = data_preprocessing.instantiate_pipeline(
-        dataset=DATASET, data_dir=self.temp_data_dir, params=params)
+        dataset=DATASET, data_dir=self.temp_data_dir, params=params,
+        constructor_type=constructor_type, deterministic=True)
    producer.start()
@@ -248,10 +275,13 @@ class BaseTest(tf.test.TestCase):
    assert producer._fatal_exception is None

    positive_counts, negative_counts = defaultdict(int), defaultdict(int)
+    md5 = hashlib.md5()
    for features, labels in results:
-      for u, i, v, l in zip(
-          features[movielens.USER_COLUMN], features[movielens.ITEM_COLUMN],
-          features[rconst.VALID_POINT_MASK], labels):
+      data_list = [
+          features[movielens.USER_COLUMN], features[movielens.ITEM_COLUMN],
+          features[rconst.VALID_POINT_MASK], labels]
+      [md5.update(i.tobytes()) for i in data_list]
+      for u, i, v, l in zip(*data_list):
        if not v:
          continue  # ignore padding
@@ -260,6 +290,8 @@ class BaseTest(tf.test.TestCase):
        else:
          negative_counts[(u, i)] += 1

+    self.assertRegexpMatches(md5.hexdigest(), FRESH_RANDOMNESS_MD5)
+
    # The positive examples should appear exactly once each epoch
    self.assertAllEqual(list(positive_counts.values()),
                        [train_epochs for _ in positive_counts])
@@ -301,6 +333,18 @@ class BaseTest(tf.test.TestCase):
    self.assertLess(deviation, 0.2)

+  def test_end_to_end_materialized(self):
+    self._test_end_to_end("materialized")
+
+  def test_end_to_end_bisection(self):
+    self._test_end_to_end("bisection")
+
+  def test_fresh_randomness_materialized(self):
+    self._test_fresh_randomness("materialized")
+
+  def test_fresh_randomness_bisection(self):
+    self._test_fresh_randomness("bisection")
+

if __name__ == "__main__":
  tf.logging.set_verbosity(tf.logging.INFO)
...
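The tests above fingerprint every produced batch and compare the result against the golden digests defined at the top of the file (END_TO_END_TRAIN_MD5, END_TO_END_EVAL_MD5, FRESH_RANDOMNESS_MD5). A minimal sketch of that pattern follows; it is not from the commit, the `fingerprint` helper name is made up, and it only yields a stable digest when the producer runs deterministically, because both dtype and element order feed the hash.

# Sketch (not from the commit): hash each produced array in order, then
# compare the hex digest against a precomputed golden value.
import hashlib
import numpy as np

def fingerprint(batches):
  md5 = hashlib.md5()
  for arrays in batches:        # each element: a tuple/list of np.ndarrays
    for a in arrays:
      md5.update(a.tobytes())   # raw bytes, so dtype and order both matter
  return md5.hexdigest()

# e.g. in a test: self.assertEqual(fingerprint(first_epoch_arrays),
#                                  END_TO_END_TRAIN_MD5)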
@@ -192,8 +192,6 @@ def run_ncf(_):
  if FLAGS.seed is not None:
    np.random.seed(FLAGS.seed)
-    tf.logging.warning("Values may still vary from run to run due to thread "
-                       "execution ordering.")

  params = parse_flags(FLAGS)
  total_training_cycle = FLAGS.train_epochs // FLAGS.epochs_between_evals
@@ -206,7 +204,9 @@ def run_ncf(_):
    num_eval_steps = rconst.SYNTHETIC_BATCHES_PER_EPOCH
  else:
    num_users, num_items, producer = data_preprocessing.instantiate_pipeline(
-        dataset=FLAGS.dataset, data_dir=FLAGS.data_dir, params=params)
+        dataset=FLAGS.dataset, data_dir=FLAGS.data_dir, params=params,
+        constructor_type=FLAGS.constructor_type,
+        deterministic=FLAGS.seed is not None)

    num_train_steps = (producer.train_batches_per_epoch //
                       params["batches_per_step"])
@@ -383,6 +383,14 @@ def define_ncf_flags():
          "For dataset ml-20m, the threshold can be set as 0.95 which is "
          "achieved by MLPerf implementation."))

+  flags.DEFINE_enum(
+      name="constructor_type", default="bisection",
+      enum_values=["bisection", "materialized"], case_sensitive=False,
+      help=flags_core.help_wrap(
+          "Strategy to use for generating false negatives. materialized has a "
+          "precompute that scales badly, but a faster per-epoch construction "
+          "time and can be faster on very large systems."))
+
  flags.DEFINE_bool(
      name="ml_perf", default=False,
      help=flags_core.help_wrap(
@@ -414,13 +422,6 @@ def define_ncf_flags():
      name="seed", default=None, help=flags_core.help_wrap(
          "This value will be used to seed both NumPy and TensorFlow."))

-  flags.DEFINE_bool(
-      name="hash_pipeline", default=False, help=flags_core.help_wrap(
-          "This flag will perform a separate run of the pipeline and hash "
-          "batches as they are produced. \nNOTE: this will significantly slow "
-          "training. However it is useful to confirm that a random seed is "
-          "does indeed make the data pipeline deterministic."))
-
  @flags.validator("eval_batch_size", "eval_batch_size must be at least {}"
                   .format(rconst.NUM_EVAL_NEGATIVES + 1))
  def eval_size_check(eval_batch_size):
...
@@ -28,3 +28,33 @@ def get_threadpool(num_workers, init_worker=None, closing=True):
  pool = multiprocessing.pool.ThreadPool(processes=num_workers,
                                         initializer=init_worker)
  return contextlib.closing(pool) if closing else pool
+
+
+class FauxPool(object):
+  """Mimic a pool using for loops.
+
+  This class is used in place of proper pools when true determinism is desired
+  for testing or debugging.
+  """
+  def __init__(self, *args, **kwargs):
+    pass
+
+  def map(self, func, iterable, chunksize=None):
+    return [func(i) for i in iterable]
+
+  def imap(self, func, iterable, chunksize=1):
+    for i in iterable:
+      yield func(i)
+
+  def close(self):
+    pass
+
+  def terminate(self):
+    pass
+
+  def join(self):
+    pass
+
+
+def get_fauxpool(num_workers, init_worker=None, closing=True):
+  pool = FauxPool(processes=num_workers, initializer=init_worker)
+  return contextlib.closing(pool) if closing else pool
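FauxPool deliberately exposes the same map/imap/close/terminate/join surface as the real pools behind get_threadpool and get_forkpool, so a caller can swap it in behind a flag exactly as the producer does. A rough usage sketch, not part of the commit; the `deterministic` flag and `work` function here are illustrative:

from official.recommendation import popen_helper

def work(x):
  return x * x

deterministic = True  # illustrative flag
get_pool = (popen_helper.get_fauxpool if deterministic
            else popen_helper.get_threadpool)
with get_pool(6) as pool:
  # With the faux pool this runs serially and in order; with the thread pool
  # it runs on 6 workers. The results list is identical either way.
  results = pool.map(work, range(10))
assert results == [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]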
@@ -146,6 +146,10 @@ no-space-check=
# else.
single-line-if-stmt=yes

+# Allow URLs and comment type annotations to exceed the max line length as
+# neither can be easily split across lines.
+ignore-long-lines=^\s*(?:(# )?<?https?://\S+>?$|# type:)
+
[VARIABLES]
...