"src/graph/vscode:/vscode.git/clone" did not exist on "18eaad17cce0ccb358df343cd8b1479582a2712c"
Commit ec0d43ba authored by Taylor Robie

address PR comments

parent c556dad9
@@ -19,17 +19,15 @@ from __future__ import division
 from __future__ import print_function

 import atexit
-import collections
 import functools
 import os
-import pickle
-import struct
 import sys
 import tempfile
 import threading
 import time
 import timeit
 import traceback
+import typing

 import numpy as np
 import six
@@ -82,6 +80,18 @@ class DatasetManager(object):
   """
   def __init__(self, is_training, stream_files, batches_per_epoch,
                shard_root=None):
+    # type: (bool, bool, int, typing.Optional[str]) -> None
+    """Constructs a `DatasetManager` instance.
+
+    Args:
+      is_training: Boolean of whether the data provided is training or
+        evaluation data. This determines whether to reuse the data
+        (if is_training=False) and the exact structure to use when storing and
+        yielding data.
+      stream_files: Boolean indicating whether data should be serialized and
+        written to file shards.
+      batches_per_epoch: The number of batches in a single epoch.
+      shard_root: The base directory to be used when stream_files=True.
+    """
     self._is_training = is_training
     self._stream_files = stream_files
     self._writers = []
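As a rough illustration of the constructor arguments described by the new docstring (the batch count and shard directory below are hypothetical, and DatasetManager is the class defined in this file), a streaming training-phase manager might be created as:

# Hypothetical instantiation; batches_per_epoch and shard_root are made up.
manager = DatasetManager(
    is_training=True,             # training data is regenerated, not reused
    stream_files=True,            # serialize batches into file shards
    batches_per_epoch=1000,
    shard_root="/tmp/ncf_shards")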
@@ -183,9 +193,8 @@ class DatasetManager(object):
         batch_size = data[movielens.ITEM_COLUMN].shape[0]
         data[rconst.VALID_POINT_MASK] = np.less(np.arange(batch_size),
                                                 mask_start_index)
-        self._result_queue.put((data, data.pop("labels")))
-      else:
-        self._result_reuse.append(data)
+        data = (data, data.pop("labels"))
+      self._result_queue.put(data)

   def start_construction(self):
     if self._stream_files:
@@ -199,26 +208,31 @@ class DatasetManager(object):
       [writer.close() for writer in self._writers]
       self._writers = []
       self._result_queue.put(self.current_data_root)
-    elif not self._is_training:
-      self._result_queue.put(True)  # data is ready.
     self._epochs_completed += 1

   def data_generator(self, epochs_between_evals):
     """Yields examples during local training."""
     assert not self._stream_files
+    assert self._is_training or epochs_between_evals == 1

     if self._is_training:
       for _ in range(self._batches_per_epoch * epochs_between_evals):
         yield self._result_queue.get(timeout=300)

     else:
-      # Evaluation waits for all data to be ready.
-      self._result_queue.put(self._result_queue.get(timeout=300))
-      assert len(self._result_reuse) == self._batches_per_epoch
-      assert epochs_between_evals == 1
-      for i in self._result_reuse:
-        yield i
+      if self._result_reuse:
+        assert len(self._result_reuse) == self._batches_per_epoch
+
+        for i in self._result_reuse:
+          yield i
+      else:
+        # First epoch.
+        for _ in range(self._batches_per_epoch * epochs_between_evals):
+          result = self._result_queue.get(timeout=300)
+          self._result_reuse.append(result)
+          yield result

   def get_dataset(self, batch_size, epochs_between_evals):
     """Construct the dataset to be used for training and eval.
@@ -341,7 +355,7 @@ class BaseDataConstructor(threading.Thread):
           "User positives ({}) is different from item positives ({})".format(
               self._train_pos_users.shape, self._train_pos_items.shape))

-    self._train_pos_count = self._train_pos_users.shape[0]
+    (self._train_pos_count,) = self._train_pos_users.shape
     self._elements_in_epoch = (1 + num_train_negatives) * self._train_pos_count
     self.train_batches_per_epoch = self._count_batches(
         self._elements_in_epoch, train_batch_size, batches_per_train_step)
@@ -372,13 +386,12 @@ class BaseDataConstructor(threading.Thread):
         False, stream_files, self.eval_batches_per_epoch, self._shard_root)

     # Threading details
-    self._current_epoch_order_lock = threading.RLock()
     super(BaseDataConstructor, self).__init__()
     self.daemon = True
     self._stop_loop = False
     self._fatal_exception = None

-  def __repr__(self):
+  def __str__(self):
     multiplier = ("(x{} devices)".format(self._batches_per_train_step)
                   if self._batches_per_train_step > 1 else "")
     summary = SUMMARY_TEMPLATE.format(
@@ -388,24 +401,17 @@ class BaseDataConstructor(threading.Thread):
         train_batch_ct=self.train_batches_per_epoch,
         eval_pos_ct=self._num_users, eval_batch_size=self.eval_batch_size,
         eval_batch_ct=self.eval_batches_per_epoch, multiplier=multiplier)
-    return super(BaseDataConstructor, self).__repr__() + "\n" + summary
+    return super(BaseDataConstructor, self).__str__() + "\n" + summary

   @staticmethod
   def _count_batches(example_count, batch_size, batches_per_step):
+    """Determine the number of batches, rounding up to fill all devices."""
     x = (example_count + batch_size - 1) // batch_size
     return (x + batches_per_step - 1) // batches_per_step * batches_per_step

   def stop_loop(self):
     self._stop_loop = True

-  def _get_order_chunk(self):
-    with self._current_epoch_order_lock:
-      batch_indices, self._current_epoch_order = (
-          self._current_epoch_order[:self.train_batch_size],
-          self._current_epoch_order[self.train_batch_size:])
-      return batch_indices
-
   def construct_lookup_variables(self):
     """Perform any one time pre-compute work."""
     raise NotImplementedError
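The new one-line docstring on _count_batches summarizes its ceiling arithmetic; a small self-contained check with made-up numbers (none of these values come from the diff) shows how the second rounding fills out a full step across multiple devices:

def count_batches(example_count, batch_size, batches_per_step):
  # Same ceiling-division arithmetic as the method above, standalone.
  x = (example_count + batch_size - 1) // batch_size
  return (x + batches_per_step - 1) // batches_per_step * batches_per_step

# 1,000 examples at batch_size=150 need ceil(1000 / 150) = 7 batches; with
# 4 batches consumed per step (e.g. 4 devices), that rounds up to 8 so the
# final step still has a batch for every device.
assert count_batches(1000, 150, 1) == 7
assert count_batches(1000, 150, 4) == 8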
@@ -429,7 +435,7 @@ class BaseDataConstructor(threading.Thread):
     except Exception as e:
       # The Thread base class swallows stack traces, so unfortunately it is
       # necessary to catch and re-raise to get debug output
-      print(traceback.format_exc(), file=sys.stderr)
+      traceback.print_exc()
       self._fatal_exception = e
       sys.stderr.flush()
       raise
@@ -448,8 +454,9 @@ class BaseDataConstructor(threading.Thread):
       i: The index of the batch. This is used when stream_files=True to assign
         data to file shards.
     """
-    batch_indices = self._get_order_chunk()
-    mask_start_index = batch_indices.shape[0]
+    batch_indices = self._current_epoch_order[i * self.train_batch_size:
+                                              (i + 1) * self.train_batch_size]
+    (mask_start_index,) = batch_indices.shape

     batch_ind_mod = np.mod(batch_indices, self._train_pos_count)
     users = self._train_pos_users[batch_ind_mod]
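With _get_order_chunk and its lock gone, each call now slices its own contiguous chunk of the shuffled epoch order by batch index, so worker threads no longer contend over shared state. A small sketch of the same index bookkeeping with made-up sizes (none of these values come from the diff):

import numpy as np

train_batch_size = 4
train_pos_count = 6
elements_in_epoch = 12   # e.g. (1 + num_train_negatives) * train_pos_count
epoch_order = np.random.permutation(elements_in_epoch)

i = 2                    # batch index handled by one worker
batch_indices = epoch_order[i * train_batch_size:(i + 1) * train_batch_size]
# Indices past the number of positives wrap back into range so they can
# still select a user row; those positions become negative examples.
batch_ind_mod = np.mod(batch_indices, train_pos_count)
assert batch_ind_mod.max() < train_pos_count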
@@ -462,7 +469,7 @@ class BaseDataConstructor(threading.Thread):
     items = self._train_pos_items[batch_ind_mod]
     items[negative_indices] = negative_items

-    labels = np.logical_not(negative_indices).astype(np.bool)
+    labels = np.logical_not(negative_indices)

     # Pad last partial batch
     pad_length = self.train_batch_size - mask_start_index
@@ -502,8 +509,7 @@ class BaseDataConstructor(threading.Thread):
     self._train_dataset.start_construction()
     map_args = list(range(self.train_batches_per_epoch))

-    assert not self._current_epoch_order.shape[0]
-    self._current_epoch_order = six.next(self._shuffle_iterator)
+    self._current_epoch_order = next(self._shuffle_iterator)

     with popen_helper.get_threadpool(6) as pool:
       pool.map(self._get_training_batch, map_args)
@@ -536,7 +542,7 @@ class BaseDataConstructor(threading.Thread):
     items = np.concatenate([positive_items, negative_items], axis=1)

     # We pad the users and items here so that the duplicate mask calculation
-    # will include the padding. The metric function relies on every element
+    # will include padding. The metric function relies on all padded elements
     # except the positive being marked as duplicate to mask out padded points.
     if users.shape[0] < users_per_batch:
       pad_rows = users_per_batch - users.shape[0]
@@ -592,6 +598,8 @@ class BaseDataConstructor(threading.Thread):
                      timeit.default_timer() - start_time))

   def make_input_fn(self, is_training):
+    # It isn't feasible to provide a foolproof check, so this is designed to
+    # catch most failures rather than provide an exhaustive guard.
     if self._fatal_exception is not None:
       raise ValueError("Fatal exception in the data production loop: {}"
                        .format(self._fatal_exception))
@@ -616,7 +624,7 @@ class DummyConstructor(threading.Thread):
     def input_fn(params):
       """Generated input_fn for the given epoch."""
       batch_size = (params["batch_size"] if is_training else
-                    params["eval_batch_size"] or params["batch_size"])
+                    params["eval_batch_size"])
       num_users = params["num_users"]
       num_items = params["num_items"]
@@ -657,7 +665,7 @@ class MaterializedDataConstructor(BaseDataConstructor):

   This class creates a table (num_users x num_items) containing all of the
   negative examples for each user. This table is conceptually ragged; that is to
-  say the items dimension will have elements at the end which are not used equal
+  say the items dimension will have a number of unused elements at the end equal
   to the number of positive elements for a given user. For instance:

   num_users = 3
@@ -693,7 +701,7 @@ class MaterializedDataConstructor(BaseDataConstructor):
     start_time = timeit.default_timer()
     inner_bounds = np.argwhere(self._train_pos_users[1:] -
                                self._train_pos_users[:-1])[:, 0] + 1
-    upper_bound = self._train_pos_users.shape[0]
+    (upper_bound,) = self._train_pos_users.shape
     index_bounds = [0] + inner_bounds.tolist() + [upper_bound]
     self._negative_table = np.zeros(shape=(self._num_users, self._num_items),
                                     dtype=rconst.ITEM_DTYPE)
......
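The hunk above derives per-user row spans for the ragged negative table from the sorted user vector: a boundary falls wherever consecutive entries differ. A tiny sketch of the same bookkeeping with made-up data (values are illustrative only):

import numpy as np

train_pos_users = np.array([0, 0, 0, 1, 1, 2])  # sorted user id per positive
inner_bounds = np.argwhere(train_pos_users[1:] -
                           train_pos_users[:-1])[:, 0] + 1
(upper_bound,) = train_pos_users.shape
index_bounds = [0] + inner_bounds.tolist() + [upper_bound]
# index_bounds == [0, 3, 5, 6]: user 0 owns slots 0-2, user 1 slots 3-4, and
# user 2 slot 5. The tail of each user's row in the (num_users x num_items)
# negative table is simply left unused, which is why it is conceptually ragged.
assert index_bounds == [0, 3, 5, 6]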
@@ -114,7 +114,7 @@ def construct_estimator(model_dir, params):


 def log_and_get_hooks(eval_batch_size):
-  """Convenience method for hook and logger creation."""
+  """Convenience function for hook and logger creation."""
   # Create hooks that log information about the training and metric values
   train_hooks = hooks_helper.get_train_hooks(
       FLAGS.hooks,
@@ -140,19 +140,16 @@ def log_and_get_hooks(eval_batch_size):


 def parse_flags(flags_obj):
-  """Convenience method to turn flags into params."""
+  """Convenience function to turn flags into params."""
   num_gpus = flags_core.get_num_gpus(flags_obj)
   num_devices = FLAGS.num_tpu_shards if FLAGS.tpu else num_gpus or 1

-  batch_size = distribution_utils.per_device_batch_size(
-      (int(flags_obj.batch_size) + num_devices - 1) //
-      num_devices * num_devices, num_devices)
+  batch_size = (flags_obj.batch_size + num_devices - 1) // num_devices

   eval_divisor = (rconst.NUM_EVAL_NEGATIVES + 1) * num_devices
-  eval_batch_size = int(flags_obj.eval_batch_size or flags_obj.batch_size or 1)
-  eval_batch_size = distribution_utils.per_device_batch_size(
-      (eval_batch_size + eval_divisor - 1) //
-      eval_divisor * eval_divisor, num_devices)
+  eval_batch_size = flags_obj.eval_batch_size or flags_obj.batch_size
+  eval_batch_size = ((eval_batch_size + eval_divisor - 1) //
+                     eval_divisor * eval_divisor // num_devices)

   return {
       "train_epochs": flags_obj.train_epochs,
......
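The rewritten flag parsing rounds the requested eval batch size up to a multiple of (rconst.NUM_EVAL_NEGATIVES + 1) * num_devices and then converts it to a per-device size, so every device evaluates whole users. A worked example with assumed values (NUM_EVAL_NEGATIVES is taken to be 999 here, i.e. one positive plus 999 negatives per evaluation user; the device count and requested size are made up):

num_devices = 8
num_eval_negatives = 999                    # assumed rconst.NUM_EVAL_NEGATIVES
eval_divisor = (num_eval_negatives + 1) * num_devices   # 8000

requested = 100000
eval_batch_size = ((requested + eval_divisor - 1) //
                   eval_divisor * eval_divisor // num_devices)
# ceil(100000 / 8000) = 13 groups of 8000 examples -> 104000 in total,
# i.e. 13000 examples per device.
assert eval_batch_size == 13000

# The train batch size is a plain ceiling division over devices:
batch_size = (requested + num_devices - 1) // num_devices
assert batch_size == 12500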
@@ -18,27 +18,32 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import atexit
-from collections import deque
-import multiprocessing
 import os
-import struct
-import sys
-import threading
-import time

 import numpy as np

-from official.recommendation import popen_helper
-

 def random_int32():
   return np.random.randint(low=0, high=np.iinfo(np.int32).max, dtype=np.int32)


 def permutation(args):
+  """Fork safe permutation function.
+
+  This function can be called within a multiprocessing worker and give
+  appropriately random results.
+
+  Args:
+    args: A size two tuple that will unpacked into the size of the permutation
+      and the random seed. This form is used because starmap is not universally
+      available.
+
+  returns:
+    A NumPy array containing a random permutation.
+  """
   x, seed = args
-  seed = seed or struct.unpack("<L", os.urandom(4))[0]
+
+  # If seed is None NumPy will seed randomly.
   state = np.random.RandomState(seed=seed)  # pylint: disable=no-member
   output = np.arange(x, dtype=np.int32)
   state.shuffle(output)
......
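The new docstring explains that permutation() takes a single (size, seed) tuple so it can be dispatched with a plain pool map. An illustrative usage sketch (the pool size and seeds are arbitrary, and the function body is restated here without the docstring so the snippet runs on its own):

import multiprocessing

import numpy as np

def permutation(args):
  # Condensed restatement of the function above.
  x, seed = args
  state = np.random.RandomState(seed=seed)
  output = np.arange(x, dtype=np.int32)
  state.shuffle(output)
  return output

if __name__ == "__main__":
  pool = multiprocessing.Pool(2)
  perms = pool.map(permutation, [(10, 0), (10, 1)])  # one (size, seed) per task
  pool.close()
  pool.join()
  # Each worker produced a full permutation of range(10).
  assert all(sorted(p.tolist()) == list(range(10)) for p in perms)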