Commit ec0d43ba authored by Taylor Robie

address PR comments

parent c556dad9
......@@ -19,17 +19,15 @@ from __future__ import division
from __future__ import print_function
import atexit
import collections
import functools
import os
import pickle
import struct
import sys
import tempfile
import threading
import time
import timeit
import traceback
import typing
import numpy as np
import six
......@@ -82,6 +80,18 @@ class DatasetManager(object):
"""
def __init__(self, is_training, stream_files, batches_per_epoch,
shard_root=None):
# type: (bool, bool, int, typing.Optional[str]) -> None
"""Constructs a `DatasetManager` instance.
Args:
is_training: Boolean indicating whether the data provided is training or
evaluation data. This determines whether to reuse the data
(if is_training=False) and the exact structure to use when storing and
yielding data.
stream_files: Boolean indicating whether data should be serialized and
written to file shards.
batches_per_epoch: The number of batches in a single epoch.
shard_root: The base directory to be used when stream_files=True.
"""
self._is_training = is_training
self._stream_files = stream_files
self._writers = []
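For orientation, a minimal sketch of how this constructor might be parameterized; the values and the shard directory below are illustrative and do not come from the change itself:

# Hypothetical training-side manager that serializes batches to file shards.
train_manager = DatasetManager(
    is_training=True, stream_files=True,
    batches_per_epoch=1000, shard_root="/tmp/ncf_shards")

# Hypothetical eval-side manager that keeps batches in memory and reuses
# the cached batches across evaluations.
eval_manager = DatasetManager(
    is_training=False, stream_files=False, batches_per_epoch=50)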
......@@ -183,9 +193,8 @@ class DatasetManager(object):
batch_size = data[movielens.ITEM_COLUMN].shape[0]
data[rconst.VALID_POINT_MASK] = np.less(np.arange(batch_size),
mask_start_index)
self._result_queue.put((data, data.pop("labels")))
else:
self._result_reuse.append(data)
data = (data, data.pop("labels"))
self._result_queue.put(data)
def start_construction(self):
if self._stream_files:
......@@ -199,26 +208,31 @@ class DatasetManager(object):
[writer.close() for writer in self._writers]
self._writers = []
self._result_queue.put(self.current_data_root)
elif not self._is_training:
self._result_queue.put(True) # data is ready.
self._epochs_completed += 1
def data_generator(self, epochs_between_evals):
"""Yields examples during local training."""
assert not self._stream_files
assert self._is_training or epochs_between_evals == 1
if self._is_training:
for _ in range(self._batches_per_epoch * epochs_between_evals):
yield self._result_queue.get(timeout=300)
else:
# Evaluation waits for all data to be ready.
self._result_queue.put(self._result_queue.get(timeout=300))
if self._result_reuse:
assert len(self._result_reuse) == self._batches_per_epoch
assert epochs_between_evals == 1
for i in self._result_reuse:
yield i
else:
# First epoch.
for _ in range(self._batches_per_epoch * epochs_between_evals):
result = self._result_queue.get(timeout=300)
self._result_reuse.append(result)
yield result
def get_dataset(self, batch_size, epochs_between_evals):
"""Construct the dataset to be used for training and eval.
......@@ -341,7 +355,7 @@ class BaseDataConstructor(threading.Thread):
"User positives ({}) is different from item positives ({})".format(
self._train_pos_users.shape, self._train_pos_items.shape))
self._train_pos_count = self._train_pos_users.shape[0]
(self._train_pos_count,) = self._train_pos_users.shape
self._elements_in_epoch = (1 + num_train_negatives) * self._train_pos_count
self.train_batches_per_epoch = self._count_batches(
self._elements_in_epoch, train_batch_size, batches_per_train_step)
......@@ -372,13 +386,12 @@ class BaseDataConstructor(threading.Thread):
False, stream_files, self.eval_batches_per_epoch, self._shard_root)
# Threading details
self._current_epoch_order_lock = threading.RLock()
super(BaseDataConstructor, self).__init__()
self.daemon = True
self._stop_loop = False
self._fatal_exception = None
def __repr__(self):
def __str__(self):
multiplier = ("(x{} devices)".format(self._batches_per_train_step)
if self._batches_per_train_step > 1 else "")
summary = SUMMARY_TEMPLATE.format(
......@@ -388,24 +401,17 @@ class BaseDataConstructor(threading.Thread):
train_batch_ct=self.train_batches_per_epoch,
eval_pos_ct=self._num_users, eval_batch_size=self.eval_batch_size,
eval_batch_ct=self.eval_batches_per_epoch, multiplier=multiplier)
return super(BaseDataConstructor, self).__repr__() + "\n" + summary
return super(BaseDataConstructor, self).__str__() + "\n" + summary
@staticmethod
def _count_batches(example_count, batch_size, batches_per_step):
"""Determine the number of batches, rounding up to fill all devices."""
x = (example_count + batch_size - 1) // batch_size
return (x + batches_per_step - 1) // batches_per_step * batches_per_step
def stop_loop(self):
self._stop_loop = True
def _get_order_chunk(self):
with self._current_epoch_order_lock:
batch_indices, self._current_epoch_order = (
self._current_epoch_order[:self.train_batch_size],
self._current_epoch_order[self.train_batch_size:])
return batch_indices
def construct_lookup_variables(self):
"""Perform any one time pre-compute work."""
raise NotImplementedError
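As a sanity check on the rounding performed by _count_batches above, a standalone reproduction with invented numbers:

def count_batches(example_count, batch_size, batches_per_step):
  # Round up to whole batches, then up to a multiple of batches_per_step
  # so that every device receives a batch on every step.
  x = (example_count + batch_size - 1) // batch_size
  return (x + batches_per_step - 1) // batches_per_step * batches_per_step

# 10,000 examples at batch size 64 need 157 batches; with 8 devices per step
# this rounds up to 160 so that no device is left without a batch.
assert count_batches(10000, 64, 8) == 160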
......@@ -429,7 +435,7 @@ class BaseDataConstructor(threading.Thread):
except Exception as e:
# The Thread base class swallows stack traces, so unfortunately it is
# necessary to catch and re-raise to get debug output
print(traceback.format_exc(), file=sys.stderr)
traceback.print_exc()
self._fatal_exception = e
sys.stderr.flush()
raise
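The comment above describes the catch, print, and re-raise pattern for worker threads; a minimal standalone illustration of the same idea (not part of the commit):

import sys
import threading
import traceback

class _Worker(threading.Thread):
  def run(self):
    try:
      raise RuntimeError("simulated failure in the data loop")
    except Exception:
      # Print the stack trace explicitly so it is not lost with the thread,
      # then re-raise so the failure can still be detected later.
      traceback.print_exc()
      sys.stderr.flush()
      raise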
......@@ -448,8 +454,9 @@ class BaseDataConstructor(threading.Thread):
i: The index of the batch. This is used when stream_files=True to assign
data to file shards.
"""
batch_indices = self._get_order_chunk()
mask_start_index = batch_indices.shape[0]
batch_indices = self._current_epoch_order[i * self.train_batch_size:
(i + 1) * self.train_batch_size]
(mask_start_index,) = batch_indices.shape
batch_ind_mod = np.mod(batch_indices, self._train_pos_count)
users = self._train_pos_users[batch_ind_mod]
......@@ -462,7 +469,7 @@ class BaseDataConstructor(threading.Thread):
items = self._train_pos_items[batch_ind_mod]
items[negative_indices] = negative_items
labels = np.logical_not(negative_indices).astype(np.bool)
labels = np.logical_not(negative_indices)
# Pad last partial batch
pad_length = self.train_batch_size - mask_start_index
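The slicing above appears to replace the previous lock-guarded _get_order_chunk helper; a toy illustration of how batch i selects its indices (array contents invented):

import numpy as np

train_batch_size = 4
# Stand-in for the shuffled epoch order produced by the shuffle iterator.
current_epoch_order = np.array([7, 2, 9, 0, 5, 3, 8, 1, 6, 4])

i = 1  # second batch of the epoch
batch_indices = current_epoch_order[i * train_batch_size:
                                    (i + 1) * train_batch_size]
(mask_start_index,) = batch_indices.shape  # fewer than train_batch_size
                                           # only for the last partial batch
print(batch_indices, mask_start_index)  # [5 3 8 1] 4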
......@@ -502,8 +509,7 @@ class BaseDataConstructor(threading.Thread):
self._train_dataset.start_construction()
map_args = list(range(self.train_batches_per_epoch))
assert not self._current_epoch_order.shape[0]
self._current_epoch_order = six.next(self._shuffle_iterator)
self._current_epoch_order = next(self._shuffle_iterator)
with popen_helper.get_threadpool(6) as pool:
pool.map(self._get_training_batch, map_args)
......@@ -536,7 +542,7 @@ class BaseDataConstructor(threading.Thread):
items = np.concatenate([positive_items, negative_items], axis=1)
# We pad the users and items here so that the duplicate mask calculation
# will include the padding. The metric function relies on every element
# will include padding. The metric function relies on all padded elements
# except the positive being marked as duplicate to mask out padded points.
if users.shape[0] < users_per_batch:
pad_rows = users_per_batch - users.shape[0]
......@@ -592,6 +598,8 @@ class BaseDataConstructor(threading.Thread):
timeit.default_timer() - start_time))
def make_input_fn(self, is_training):
# It isn't feasible to provide a foolproof check, so this is designed to
# catch most failures rather than provide an exhaustive guard.
if self._fatal_exception is not None:
raise ValueError("Fatal exception in the data production loop: {}"
.format(self._fatal_exception))
......@@ -616,7 +624,7 @@ class DummyConstructor(threading.Thread):
def input_fn(params):
"""Generated input_fn for the given epoch."""
batch_size = (params["batch_size"] if is_training else
params["eval_batch_size"] or params["batch_size"])
params["eval_batch_size"])
num_users = params["num_users"]
num_items = params["num_items"]
......@@ -657,7 +665,7 @@ class MaterializedDataConstructor(BaseDataConstructor):
This class creates a table (num_users x num_items) containing all of the
negative examples for each user. This table is conceptually ragged; that is to
say the items dimension will have elements at the end which are not used equal
say the items dimension will have a number of unused elements at the end equal
to the number of positive elements for a given user. For instance:
num_users = 3
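The docstring's worked example is cut off by the diff; the toy illustration of the ragged layout below is ours, not the author's:

import numpy as np

num_users, num_items = 3, 5
positives = [{1}, {0, 3}, {2, 4}]  # invented positive items per user

negative_table = np.zeros((num_users, num_items), dtype=np.int32)
for user, pos in enumerate(positives):
  negatives = [item for item in range(num_items) if item not in pos]
  # Only the first len(negatives) entries of each row are meaningful; the
  # trailing len(pos) entries are unused, which is what makes the table
  # conceptually ragged.
  negative_table[user, :len(negatives)] = negatives

print(negative_table)
# [[0 2 3 4 0]
#  [1 2 4 0 0]
#  [0 1 3 0 0]]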
......@@ -693,7 +701,7 @@ class MaterializedDataConstructor(BaseDataConstructor):
start_time = timeit.default_timer()
inner_bounds = np.argwhere(self._train_pos_users[1:] -
self._train_pos_users[:-1])[:, 0] + 1
upper_bound = self._train_pos_users.shape[0]
(upper_bound,) = self._train_pos_users.shape
index_bounds = [0] + inner_bounds.tolist() + [upper_bound]
self._negative_table = np.zeros(shape=(self._num_users, self._num_items),
dtype=rconst.ITEM_DTYPE)
......
......@@ -114,7 +114,7 @@ def construct_estimator(model_dir, params):
def log_and_get_hooks(eval_batch_size):
"""Convenience method for hook and logger creation."""
"""Convenience function for hook and logger creation."""
# Create hooks that log information about the training and metric values
train_hooks = hooks_helper.get_train_hooks(
FLAGS.hooks,
......@@ -140,19 +140,16 @@ def log_and_get_hooks(eval_batch_size):
def parse_flags(flags_obj):
"""Convenience method to turn flags into params."""
"""Convenience function to turn flags into params."""
num_gpus = flags_core.get_num_gpus(flags_obj)
num_devices = FLAGS.num_tpu_shards if FLAGS.tpu else num_gpus or 1
batch_size = distribution_utils.per_device_batch_size(
(int(flags_obj.batch_size) + num_devices - 1) //
num_devices * num_devices, num_devices)
batch_size = (flags_obj.batch_size + num_devices - 1) // num_devices
eval_divisor = (rconst.NUM_EVAL_NEGATIVES + 1) * num_devices
eval_batch_size = int(flags_obj.eval_batch_size or flags_obj.batch_size or 1)
eval_batch_size = distribution_utils.per_device_batch_size(
(eval_batch_size + eval_divisor - 1) //
eval_divisor * eval_divisor, num_devices)
eval_batch_size = flags_obj.eval_batch_size or flags_obj.batch_size
eval_batch_size = ((eval_batch_size + eval_divisor - 1) //
eval_divisor * eval_divisor // num_devices)
return {
"train_epochs": flags_obj.train_epochs,
......
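To make the eval batch size rounding in parse_flags concrete, a small standalone reproduction; the flag values are invented, and the constant is the value we believe rconst.NUM_EVAL_NEGATIVES has in this pipeline (999 negatives per positive), which should be treated as an assumption:

NUM_EVAL_NEGATIVES = 999  # assumed value of rconst.NUM_EVAL_NEGATIVES
num_devices = 4
eval_divisor = (NUM_EVAL_NEGATIVES + 1) * num_devices  # 4000

requested = 100000  # stand-in for --eval_batch_size (or the --batch_size fallback)
# Round up to a multiple of eval_divisor so every device sees whole users,
# then divide by the device count to get the per-device batch size.
eval_batch_size = ((requested + eval_divisor - 1) //
                   eval_divisor * eval_divisor // num_devices)
print(eval_batch_size)  # 25000, i.e. 25 complete users of 1000 candidates each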
......@@ -18,27 +18,32 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import atexit
from collections import deque
import multiprocessing
import os
import struct
import sys
import threading
import time
import numpy as np
from official.recommendation import popen_helper
def random_int32():
return np.random.randint(low=0, high=np.iinfo(np.int32).max, dtype=np.int32)
def permutation(args):
"""Fork safe permutation function.
This function can be called within a multiprocessing worker and give
appropriately random results.
Args:
args: A size two tuple that will unpacked into the size of the permutation
and the random seed. This form is used because starmap is not universally
available.
returns:
A NumPy array containing a random permutation.
"""
x, seed = args
seed = seed or struct.unpack("<L", os.urandom(4))[0]
# If seed is None NumPy will seed randomly.
state = np.random.RandomState(seed=seed) # pylint: disable=no-member
output = np.arange(x, dtype=np.int32)
state.shuffle(output)
......
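A hedged usage sketch of the permutation helper defined above; a plain multiprocessing.Pool is used here rather than the repository's popen_helper wrappers, since the exact wrapper API is not shown in this diff:

import multiprocessing

if __name__ == "__main__":
  # One (size, seed) tuple per worker; an explicit seed makes the result
  # reproducible, while None lets the seeding fall back to a random source.
  args = [(10, 1), (10, 2), (10, None)]
  with multiprocessing.Pool(processes=3) as pool:
    perms = pool.map(permutation, args)
  for perm in perms:
    assert sorted(perm.tolist()) == list(range(10))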