Commit 999fae62 authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 326286926
parent 94561082
@@ -41,7 +41,6 @@ from official.recommendation import popen_helper
from official.recommendation import stat_utils
from tensorflow.python.tpu.datasets import StreamingFilesDataset

SUMMARY_TEMPLATE = """General:
{spacer}Num users: {num_users}
{spacer}Num items: {num_items}
@@ -74,25 +73,27 @@ class DatasetManager(object):
               num_train_epochs=None):
    # type: (bool, bool, int, typing.Optional[str], bool, int) -> None
    """Constructs a `DatasetManager` instance.

    Args:
      is_training: Boolean of whether the data provided is training or
        evaluation data. This determines whether to reuse the data (if
        is_training=False) and the exact structure to use when storing and
        yielding data.
      stream_files: Boolean indicating whether data should be serialized and
        written to file shards.
      batches_per_epoch: The number of batches in a single epoch.
      shard_root: The base directory to be used when stream_files=True.
      deterministic: Forgo non-deterministic speedups. (i.e. sloppy=True)
      num_train_epochs: Number of epochs to generate. If None, then each call to
        `get_dataset()` increments the number of epochs requested.
    """
    self._is_training = is_training
    self._deterministic = deterministic
    self._stream_files = stream_files
    self._writers = []
    self._write_locks = [
        threading.RLock() for _ in range(rconst.NUM_FILE_SHARDS)
    ] if stream_files else []
    self._batches_per_epoch = batches_per_epoch
    self._epochs_completed = 0
    self._epochs_requested = num_train_epochs if num_train_epochs else 0
@@ -103,8 +104,9 @@ class DatasetManager(object):
  @property
  def current_data_root(self):
    subdir = (
        rconst.TRAIN_FOLDER_TEMPLATE.format(self._epochs_completed)
        if self._is_training else rconst.EVAL_FOLDER)
    return os.path.join(self._shard_root, subdir)

  def buffer_reached(self):
@@ -123,8 +125,8 @@ class DatasetManager(object):
        k: create_int_feature(v.astype(np.int64)) for k, v in data.items()
    }
    return tf.train.Example(features=tf.train.Features(
        feature=feature_dict)).SerializeToString()
  @staticmethod
  def deserialize(serialized_data, batch_size=None, is_training=True):
@@ -134,8 +136,8 @@ class DatasetManager(object):
      serialized_data: A tensor containing serialized records.
      batch_size: The data arrives pre-batched, so batch size is needed to
        deserialize the data.
      is_training: Boolean, whether the data deserializes to training data or
        evaluation data.
    """

    def _get_feature_map(batch_size, is_training=True):
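Note: to make the pre-batched layout concrete, a minimal hedged sketch (the feature names `user_id`/`item_id` are hypothetical): one serialized tf.train.Example holds an entire batch, which is why the parser needs batch_size to build fixed-length feature specs.

    import tensorflow as tf

    def parse_prebatched(serialized_example, batch_size):
      # Each record is a single Example containing a whole batch, so every
      # feature is a fixed-length int64 vector of length batch_size.
      feature_map = {
          "user_id": tf.io.FixedLenFeature([batch_size], dtype=tf.int64),
          "item_id": tf.io.FixedLenFeature([batch_size], dtype=tf.int64),
      }
      return tf.io.parse_single_example(serialized_example, feature_map)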
@@ -171,13 +173,16 @@ class DatasetManager(object):
        valid_point_mask = tf.cast(features[rconst.VALID_POINT_MASK], tf.bool)
        fake_dup_mask = tf.zeros_like(users)
        return {
            movielens.USER_COLUMN:
                users,
            movielens.ITEM_COLUMN:
                items,
            rconst.VALID_POINT_MASK:
                valid_point_mask,
            rconst.TRAIN_LABEL_KEY:
                tf.reshape(tf.cast(features["labels"], tf.bool), (batch_size, 1)),
            rconst.DUPLICATE_MASK:
                fake_dup_mask
        }
      else:
        labels = tf.cast(tf.zeros_like(users), tf.bool)
@@ -228,8 +233,10 @@ class DatasetManager(object):
    if self._stream_files:
      tf.io.gfile.makedirs(self.current_data_root)
      template = os.path.join(self.current_data_root, rconst.SHARD_TEMPLATE)
      self._writers = [
          tf.io.TFRecordWriter(template.format(i))
          for i in range(rconst.NUM_FILE_SHARDS)
      ]

  def end_construction(self):
    if self._stream_files:
@@ -273,8 +280,8 @@ class DatasetManager(object):
    Args:
      batch_size: The per-replica batch size of the dataset.
      epochs_between_evals: How many epochs worth of data to yield. (Generator
        mode only.)
    """
    self.increment_request_epoch()
    if self._stream_files:
@@ -285,11 +292,13 @@ class DatasetManager(object):
      if not self._is_training:
        self._result_queue.put(epoch_data_dir)  # Eval data is reused.
      file_pattern = os.path.join(epoch_data_dir,
                                  rconst.SHARD_TEMPLATE.format("*"))
      dataset = StreamingFilesDataset(
          files=file_pattern,
          worker_job=popen_helper.worker_job(),
          num_parallel_reads=rconst.NUM_FILE_SHARDS,
          num_epochs=1,
          sloppy=not self._deterministic)
      map_fn = functools.partial(
          self.deserialize,
@@ -298,8 +307,10 @@ class DatasetManager(object):
      dataset = dataset.map(map_fn, num_parallel_calls=16)
    else:
      types = {
          movielens.USER_COLUMN: rconst.USER_DTYPE,
          movielens.ITEM_COLUMN: rconst.ITEM_DTYPE
      }
      shapes = {
          movielens.USER_COLUMN: tf.TensorShape([batch_size, 1]),
          movielens.ITEM_COLUMN: tf.TensorShape([batch_size, 1])
@@ -319,8 +330,7 @@ class DatasetManager(object):
      data_generator = functools.partial(
          self.data_generator, epochs_between_evals=epochs_between_evals)
      dataset = tf.data.Dataset.from_generator(
          generator=data_generator, output_types=types, output_shapes=shapes)

    return dataset.prefetch(16)
@@ -332,16 +342,17 @@ class DatasetManager(object):
      # Estimator passes batch_size during training and eval_batch_size during
      # eval.
      param_batch_size = (
          params["batch_size"] if self._is_training else
          params.get("eval_batch_size") or params["batch_size"])
      if batch_size != param_batch_size:
        raise ValueError("producer batch size ({}) differs from params batch "
                         "size ({})".format(batch_size, param_batch_size))

      epochs_between_evals = (
          params.get("epochs_between_evals", 1) if self._is_training else 1)
      return self.get_dataset(
          batch_size=batch_size, epochs_between_evals=epochs_between_evals)

    return input_fn
@@ -405,15 +416,16 @@ class BaseDataConstructor(threading.Thread):
    (self._train_pos_count,) = self._train_pos_users.shape
    self._elements_in_epoch = (1 + num_train_negatives) * self._train_pos_count
    self.train_batches_per_epoch = self._count_batches(self._elements_in_epoch,
                                                       train_batch_size,
                                                       batches_per_train_step)

    # Evaluation
    if eval_batch_size % (1 + rconst.NUM_EVAL_NEGATIVES):
      raise ValueError("Eval batch size {} is not divisible by {}".format(
          eval_batch_size, 1 + rconst.NUM_EVAL_NEGATIVES))
    self._eval_users_per_batch = int(eval_batch_size //
                                     (1 + rconst.NUM_EVAL_NEGATIVES))
    self._eval_elements_in_epoch = num_users * (1 + rconst.NUM_EVAL_NEGATIVES)
    self.eval_batches_per_epoch = self._count_batches(
        self._eval_elements_in_epoch, eval_batch_size, batches_per_eval_step)
@@ -450,12 +462,16 @@ class BaseDataConstructor(threading.Thread):
    multiplier = ("(x{} devices)".format(self._batches_per_train_step)
                  if self._batches_per_train_step > 1 else "")
    summary = SUMMARY_TEMPLATE.format(
        spacer=" ",
        num_users=self._num_users,
        num_items=self._num_items,
        train_pos_ct=self._train_pos_count,
        train_batch_size=self.train_batch_size,
        train_batch_ct=self.train_batches_per_epoch,
        eval_pos_ct=self._num_users,
        eval_batch_size=self.eval_batch_size,
        eval_batch_ct=self.eval_batches_per_epoch,
        multiplier=multiplier)
    return super(BaseDataConstructor, self).__str__() + "\n" + summary
  @staticmethod
@@ -514,8 +530,9 @@ class BaseDataConstructor(threading.Thread):
      i: The index of the batch. This is used when stream_files=True to assign
        data to file shards.
    """
    batch_indices = self._current_epoch_order[i *
                                              self.train_batch_size:(i + 1) *
                                              self.train_batch_size]
    (mask_start_index,) = batch_indices.shape
    batch_ind_mod = np.mod(batch_indices, self._train_pos_count)
@@ -578,8 +595,9 @@ class BaseDataConstructor(threading.Thread):
    map_args = list(range(self.train_batches_per_epoch))
    self._current_epoch_order = next(self._shuffle_iterator)

    get_pool = (
        popen_helper.get_fauxpool
        if self.deterministic else popen_helper.get_threadpool)
    with get_pool(6) as pool:
      pool.map(self._get_training_batch, map_args)
    self._train_dataset.end_construction()
@@ -602,8 +620,8 @@ class BaseDataConstructor(threading.Thread):
      users: An array of users in a batch. (should be identical along axis 1)
      positive_items: An array (batch_size x 1) of positive item indices.
      negative_items: An array of negative item indices.
      users_per_batch: How many users should be in the batch. This is passed as
        an argument so that ncf_test.py can use this method.

    Returns:
      User, item, and duplicate_mask arrays.
@@ -635,11 +653,14 @@ class BaseDataConstructor(threading.Thread):
    """
    low_index = i * self._eval_users_per_batch
    high_index = (i + 1) * self._eval_users_per_batch
    users = np.repeat(
        self._eval_pos_users[low_index:high_index, np.newaxis],
        1 + rconst.NUM_EVAL_NEGATIVES,
        axis=1)
    positive_items = self._eval_pos_items[low_index:high_index, np.newaxis]
    negative_items = (
        self.lookup_negative_items(negative_users=users[:, :-1]).reshape(
            -1, rconst.NUM_EVAL_NEGATIVES))

    users, items, duplicate_mask = self._assemble_eval_batch(
        users, positive_items, negative_items, self._eval_users_per_batch)
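Note: a minimal sketch of the eval batch layout produced by the np.repeat call above, with a toy negative count (the real constant is much larger):

    import numpy as np

    NUM_EVAL_NEGATIVES = 3            # toy value for illustration
    eval_pos_users = np.array([7, 9])
    users = np.repeat(
        eval_pos_users[:, np.newaxis], 1 + NUM_EVAL_NEGATIVES, axis=1)
    # users == [[7, 7, 7, 7],
    #           [9, 9, 9, 9]]: one row per user, holding one positive slot
    # plus NUM_EVAL_NEGATIVES negative slots.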
@@ -664,8 +685,9 @@ class BaseDataConstructor(threading.Thread):
    self._eval_dataset.start_construction()
    map_args = [i for i in range(self.eval_batches_per_epoch)]

    get_pool = (
        popen_helper.get_fauxpool
        if self.deterministic else popen_helper.get_threadpool)
    with get_pool(6) as pool:
      pool.map(self._get_eval_batch, map_args)
    self._eval_dataset.end_construction()
@@ -677,12 +699,12 @@ class BaseDataConstructor(threading.Thread):
    # It isn't feasible to provide a foolproof check, so this is designed to
    # catch most failures rather than provide an exhaustive guard.
    if self._fatal_exception is not None:
      raise ValueError("Fatal exception in the data production loop: {}".format(
          self._fatal_exception))

    return (self._train_dataset.make_input_fn(self.train_batch_size)
            if is_training else self._eval_dataset.make_input_fn(
                self.eval_batch_size))

  def increment_request_epoch(self):
    self._train_dataset.increment_request_epoch()
@@ -714,8 +736,9 @@ class DummyConstructor(threading.Thread):
      # Estimator passes batch_size during training and eval_batch_size during
      # eval.
      batch_size = (
          params["batch_size"] if is_training else
          params.get("eval_batch_size") or params["batch_size"])
      num_users = params["num_users"]
      num_items = params["num_items"]
@@ -795,6 +818,7 @@ class MaterializedDataConstructor(BaseDataConstructor):
  a pre-compute which is quadratic in problem size will still fit in memory. A
  more scalable lookup method is in the works.
  """

  def __init__(self, *args, **kwargs):
    super(MaterializedDataConstructor, self).__init__(*args, **kwargs)
    self._negative_table = None
@@ -807,8 +831,8 @@ class MaterializedDataConstructor(BaseDataConstructor):
                                 self._train_pos_users[:-1])[:, 0] + 1
    (upper_bound,) = self._train_pos_users.shape
    index_bounds = [0] + inner_bounds.tolist() + [upper_bound]
    self._negative_table = np.zeros(
        shape=(self._num_users, self._num_items), dtype=rconst.ITEM_DTYPE)

    # Set the table to the max value to make sure the embedding lookup will fail
    # if we go out of bounds, rather than just overloading item zero.
@@ -825,7 +849,7 @@ class MaterializedDataConstructor(BaseDataConstructor):
    # call does not parallelize well. Multiprocessing incurs too much
    # serialization overhead to be worthwhile.
    for i in range(self._num_users):
      positives = self._train_pos_items[index_bounds[i]:index_bounds[i + 1]]
      negatives = np.delete(full_set, positives)
      self._per_user_neg_count[i] = self._num_items - positives.shape[0]
      self._negative_table[i, :self._per_user_neg_count[i]] = negatives
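Note: a self-contained toy sketch of the materialized negative table built above, under assumed toy sizes (the real code derives positives from sorted index bounds rather than a list of lists):

    import numpy as np

    num_users, num_items = 3, 6
    positives_by_user = [[1, 4], [0], [2, 3, 5]]  # hypothetical positives
    full_set = np.arange(num_items)

    # Pad with num_items (an out-of-range id) so a bad lookup fails loudly
    # instead of silently returning item zero.
    negative_table = np.full((num_users, num_items), num_items, dtype=np.int32)
    per_user_neg_count = np.zeros(num_users, dtype=np.int32)
    for u in range(num_users):
      negatives = np.delete(full_set, positives_by_user[u])
      per_user_neg_count[u] = negatives.shape[0]
      negative_table[u, :per_user_neg_count[u]] = negatives
    # Sampling a negative for user u is then negative_table[u, r] with
    # r drawn uniformly from [0, per_user_neg_count[u]).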
@@ -848,6 +872,7 @@ class BisectionDataConstructor(BaseDataConstructor):
  it at which point the item id for the ith negative is a simple algebraic
  expression.
  """

  def __init__(self, *args, **kwargs):
    super(BisectionDataConstructor, self).__init__(*args, **kwargs)
    self.index_bounds = None
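Note: a minimal single-user sketch of that algebraic lookup, under assumed toy data (np.searchsorted stands in for the bisection):

    import numpy as np

    sorted_positives = np.array([2, 5, 6])  # hypothetical sorted positive items
    # negatives_below[j]: how many negative ids fall below sorted_positives[j].
    negatives_below = sorted_positives - np.arange(len(sorted_positives))

    def ith_negative(i):
      # Bisect to count the positives that precede the i-th negative, then
      # shift the raw index past them.
      j = np.searchsorted(negatives_below, i, side="right")
      return i + j

    # With items 0..9, the negatives are [0, 1, 3, 4, 7, 8, 9].
    assert [ith_negative(i) for i in range(7)] == [0, 1, 3, 4, 7, 8, 9]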
@@ -855,7 +880,7 @@ class BisectionDataConstructor(BaseDataConstructor):
    self._total_negatives = None

  def _index_segment(self, user):
    lower, upper = self.index_bounds[user:user + 2]
    items = self._sorted_train_pos_items[lower:upper]

    negatives_since_last_positive = np.concatenate(
@@ -877,11 +902,11 @@ class BisectionDataConstructor(BaseDataConstructor):
    self._sorted_train_pos_items = self._train_pos_items.copy()

    for i in range(self._num_users):
      lower, upper = self.index_bounds[i:i + 2]
      self._sorted_train_pos_items[lower:upper].sort()

    self._total_negatives = np.concatenate(
        [self._index_segment(i) for i in range(self._num_users)])

    logging.info("Negative total vector built. Time: {:.1f} seconds".format(
        timeit.default_timer() - start_time))
@@ -912,8 +937,7 @@ class BisectionDataConstructor(BaseDataConstructor):
    use_shortcut = neg_item_choice >= self._total_negatives[right_index]
    output[use_shortcut] = (
        self._sorted_train_pos_items[right_index] + 1 +
        (neg_item_choice - self._total_negatives[right_index]))[use_shortcut]

    if np.all(use_shortcut):
      # The bisection code is ill-posed when there are no elements.
@@ -943,8 +967,7 @@ class BisectionDataConstructor(BaseDataConstructor):
    output[not_use_shortcut] = (
        self._sorted_train_pos_items[right_index] -
        (self._total_negatives[right_index] - neg_item_choice))

    assert np.all(output >= 0)
...
@@ -25,6 +25,7 @@ import time
import timeit

# pylint: disable=wrong-import-order
from absl import logging
import numpy as np
import pandas as pd
@@ -37,10 +38,9 @@ from official.recommendation import constants as rconst
from official.recommendation import data_pipeline
from official.recommendation import movielens

_EXPECTED_CACHE_KEYS = (rconst.TRAIN_USER_KEY, rconst.TRAIN_ITEM_KEY,
                        rconst.EVAL_USER_KEY, rconst.EVAL_ITEM_KEY,
                        rconst.USER_MAP, rconst.ITEM_MAP)

def read_dataframe(
@@ -178,17 +178,20 @@ def _filter_index_sort(raw_rating_path: Text,
  eval_df, train_df = grouped.tail(1), grouped.apply(lambda x: x.iloc[:-1])

  data = {
      rconst.TRAIN_USER_KEY:
          train_df[movielens.USER_COLUMN].values.astype(rconst.USER_DTYPE),
      rconst.TRAIN_ITEM_KEY:
          train_df[movielens.ITEM_COLUMN].values.astype(rconst.ITEM_DTYPE),
      rconst.EVAL_USER_KEY:
          eval_df[movielens.USER_COLUMN].values.astype(rconst.USER_DTYPE),
      rconst.EVAL_ITEM_KEY:
          eval_df[movielens.ITEM_COLUMN].values.astype(rconst.ITEM_DTYPE),
      rconst.USER_MAP:
          user_map,
      rconst.ITEM_MAP:
          item_map,
      "create_time":
          time.time(),
  }

  logging.info("Writing raw data cache.")
@@ -217,8 +220,8 @@ def instantiate_pipeline(dataset,
      for the input pipeline.
    deterministic: Tell the data constructor to produce deterministically.
    epoch_dir: Directory in which to store the training epochs.
    generate_data_offline: Boolean, whether current pipeline is done offline or
      while training.
  """
  logging.info("Beginning data preprocessing.")
@@ -258,8 +261,8 @@ def instantiate_pipeline(dataset,
      create_data_offline=generate_data_offline)

  run_time = timeit.default_timer() - st
  logging.info(
      "Data preprocessing complete. Time: {:.1f} sec.".format(run_time))

  print(producer)
  return num_users, num_items, producer
@@ -23,6 +23,7 @@ import hashlib
import os

import mock
import numpy as np
import scipy.stats
import tensorflow as tf
@@ -32,7 +33,6 @@ from official.recommendation import data_preprocessing
from official.recommendation import movielens
from official.recommendation import popen_helper

DATASET = "ml-test"
NUM_USERS = 1000
NUM_ITEMS = 2000
@@ -41,7 +41,6 @@ BATCH_SIZE = 2048
EVAL_BATCH_SIZE = 4000
NUM_NEG = 4

END_TO_END_TRAIN_MD5 = "b218738e915e825d03939c5e305a2698"
END_TO_END_EVAL_MD5 = "d753d0f3186831466d6e218163a9501e"
FRESH_RANDOMNESS_MD5 = "63d0dff73c0e5f1048fbdc8c65021e22"
@@ -136,8 +135,11 @@ class BaseTest(tf.test.TestCase):
  def _test_end_to_end(self, constructor_type):
    params = self.make_params(train_epochs=1)
    _, _, producer = data_preprocessing.instantiate_pipeline(
        dataset=DATASET,
        data_dir=self.temp_data_dir,
        params=params,
        constructor_type=constructor_type,
        deterministic=True)

    producer.start()
    producer.join()
@@ -258,8 +260,11 @@ class BaseTest(tf.test.TestCase):
    train_epochs = 5
    params = self.make_params(train_epochs=train_epochs)
    _, _, producer = data_preprocessing.instantiate_pipeline(
        dataset=DATASET,
        data_dir=self.temp_data_dir,
        params=params,
        constructor_type=constructor_type,
        deterministic=True)

    producer.start()
@@ -298,8 +303,8 @@ class BaseTest(tf.test.TestCase):
    self.assertRegexpMatches(md5.hexdigest(), FRESH_RANDOMNESS_MD5)

    # The positive examples should appear exactly once each epoch
    self.assertAllEqual(
        list(positive_counts.values()), [train_epochs for _ in positive_counts])

    # The threshold for the negatives is heuristic, but in general repeats are
    # expected and should not appear too frequently.
@@ -317,8 +322,8 @@ class BaseTest(tf.test.TestCase):
    # The frequency of occurrence of a given negative pair should follow an
    # approximately binomial distribution in the limit that the cardinality of
    # the negative pair set >> number of samples per epoch.
    approx_pdf = scipy.stats.binom.pmf(
        k=np.arange(train_epochs + 1), n=train_epochs, p=e_sample)

    # Tally the actual observed counts.
    count_distribution = [0 for _ in range(train_epochs + 1)]
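Note: for intuition, a worked example of that pmf with an assumed toy sampling probability:

    import numpy as np
    import scipy.stats

    train_epochs, e_sample = 5, 0.01  # e_sample: per-epoch pick probability (toy)
    approx_pdf = scipy.stats.binom.pmf(
        k=np.arange(train_epochs + 1), n=train_epochs, p=e_sample)
    # approx_pdf[k] is the expected fraction of negative pairs sampled in
    # exactly k of the 5 epochs: approx_pdf[0] ~ 0.951, approx_pdf[1] ~ 0.048.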
......
...@@ -27,6 +27,7 @@ import tempfile ...@@ -27,6 +27,7 @@ import tempfile
import zipfile import zipfile
# pylint: disable=g-bad-import-order # pylint: disable=g-bad-import-order
# Import libraries
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import six import six
......
...@@ -12,8 +12,7 @@ ...@@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Common functionalities used by both Keras and Estimator implementations. """Common functionalities used by both Keras and Estimator implementations."""
"""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
@@ -23,6 +22,7 @@ import json
import os

# pylint: disable=g-bad-import-order
import numpy as np
from absl import flags
from absl import logging
@@ -56,7 +56,9 @@ def get_inputs(params):
    num_eval_steps = rconst.SYNTHETIC_BATCHES_PER_EPOCH
  else:
    num_users, num_items, producer = data_preprocessing.instantiate_pipeline(
        dataset=FLAGS.dataset,
        data_dir=FLAGS.data_dir,
        params=params,
        constructor_type=FLAGS.constructor_type,
        deterministic=FLAGS.seed is not None)
    num_train_steps = producer.train_batches_per_epoch
@@ -108,16 +110,17 @@ def get_v1_distribution_strategy(params):
  """Returns the distribution strategy to use."""
  if params["use_tpu"]:
    # Some of the networking libraries are quite chatty.
    for name in [
        "googleapiclient.discovery", "googleapiclient.discovery_cache",
        "oauth2client.transport"
    ]:
      logging.getLogger(name).setLevel(logging.ERROR)

    tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        tpu=params["tpu"],
        zone=params["tpu_zone"],
        project=params["tpu_gcp_project"],
        coordinator_name="coordinator")

    logging.info("Issuing reset command to TPU to ensure a clean state.")
    tf.Session.reset(tpu_cluster_resolver.get_master())
@@ -126,10 +129,12 @@ def get_v1_distribution_strategy(params):
    # by reading the `TF_CONFIG` environment variable, and the coordinator
    # is used by StreamingFilesDataset.
    tf_config_env = {
        "session_master":
            tpu_cluster_resolver.get_master(),
        "eval_session_master":
            tpu_cluster_resolver.get_master(),
        "coordinator":
            tpu_cluster_resolver.cluster_spec().as_dict()["coordinator"]
    }
    os.environ["TF_CONFIG"] = json.dumps(tf_config_env)
@@ -146,10 +151,16 @@ def get_v1_distribution_strategy(params):
def define_ncf_flags():
  """Add flags for running ncf_main."""
  # Add common flags
  flags_core.define_base(
      model_dir=True,
      clean=True,
      train_epochs=True,
      epochs_between_evals=True,
      export_dir=False,
      run_eagerly=True,
      stop_threshold=True,
      num_gpu=True,
      distribution_strategy=True)
  flags_core.define_performance(
      synthetic_data=True,
      dtype=True,
@@ -171,69 +182,82 @@ def define_ncf_flags():
      dataset=movielens.ML_1M,
      train_epochs=2,
      batch_size=99000,
      tpu=None)

  # Add ncf-specific flags
  flags.DEFINE_boolean(
      name="download_if_missing",
      default=True,
      help=flags_core.help_wrap(
          "Download data to data_dir if it is not already present."))

  flags.DEFINE_integer(
      name="eval_batch_size",
      default=None,
      help=flags_core.help_wrap(
          "The batch size used for evaluation. This should generally be larger "
          "than the training batch size as the lack of back propagation during "
          "evaluation can allow for larger batch sizes to fit in memory. If not "
          "specified, the training batch size (--batch_size) will be used."))

  flags.DEFINE_integer(
      name="num_factors",
      default=8,
      help=flags_core.help_wrap("The Embedding size of MF model."))

  # Set the default as a list of strings to be consistent with input arguments
  flags.DEFINE_list(
      name="layers",
      default=["64", "32", "16", "8"],
      help=flags_core.help_wrap(
          "The sizes of hidden layers for MLP. Example "
          "to specify different sizes of MLP layers: --layers=32,16,8,4"))

  flags.DEFINE_float(
      name="mf_regularization",
      default=0.,
      help=flags_core.help_wrap(
          "The regularization factor for MF embeddings. The factor is used by "
          "regularizer which allows to apply penalties on layer parameters or "
          "layer activity during optimization."))

  flags.DEFINE_list(
      name="mlp_regularization",
      default=["0.", "0.", "0.", "0."],
      help=flags_core.help_wrap(
          "The regularization factor for each MLP layer. See mf_regularization "
          "help for more info about regularization factor."))

  flags.DEFINE_integer(
      name="num_neg",
      default=4,
      help=flags_core.help_wrap(
          "The number of negative instances to pair with a positive instance."))

  flags.DEFINE_float(
      name="learning_rate",
      default=0.001,
      help=flags_core.help_wrap("The learning rate."))

  flags.DEFINE_float(
      name="beta1",
      default=0.9,
      help=flags_core.help_wrap("beta1 hyperparameter for the Adam optimizer."))

  flags.DEFINE_float(
      name="beta2",
      default=0.999,
      help=flags_core.help_wrap("beta2 hyperparameter for the Adam optimizer."))

  flags.DEFINE_float(
      name="epsilon",
      default=1e-8,
      help=flags_core.help_wrap("epsilon hyperparameter for the Adam "
                                "optimizer."))

  flags.DEFINE_float(
      name="hr_threshold",
      default=1.0,
      help=flags_core.help_wrap(
          "If passed, training will stop when the evaluation metric HR is "
          "greater than or equal to hr_threshold. For dataset ml-1m, the "
@@ -242,8 +266,10 @@ def define_ncf_flags():
          "achieved by MLPerf implementation."))

  flags.DEFINE_enum(
      name="constructor_type",
      default="bisection",
      enum_values=["bisection", "materialized"],
      case_sensitive=False,
      help=flags_core.help_wrap(
          "Strategy to use for generating false negatives. materialized has a "
          "precompute that scales badly, but a faster per-epoch construction "
@@ -265,7 +291,8 @@ def define_ncf_flags():
      help=flags_core.help_wrap("Path to input meta data file."))

  flags.DEFINE_bool(
      name="ml_perf",
      default=False,
      help=flags_core.help_wrap(
          "If set, changes the behavior of the model slightly to match the "
          "MLPerf reference implementations here: \n"
@@ -280,23 +307,26 @@ def define_ncf_flags():
          "not stable."))

  flags.DEFINE_bool(
      name="output_ml_perf_compliance_logging",
      default=False,
      help=flags_core.help_wrap(
          "If set, output the MLPerf compliance logging. This is only useful "
          "if one is running the model for MLPerf. See "
          "https://github.com/mlperf/policies/blob/master/training_rules.adoc"
          "#submission-compliance-logs for details. This uses sudo and so may "
          "ask for your password, as root access is needed to clear the system "
          "caches, which is required for MLPerf compliance."))

  flags.DEFINE_integer(
      name="seed",
      default=None,
      help=flags_core.help_wrap(
          "This value will be used to seed both NumPy and TensorFlow."))

  @flags.validator(
      "eval_batch_size",
      "eval_batch_size must be at least {}".format(rconst.NUM_EVAL_NEGATIVES +
                                                   1))
  def eval_size_check(eval_batch_size):
    return (eval_batch_size is None or
            int(eval_batch_size) > rconst.NUM_EVAL_NEGATIVES)
...
@@ -21,6 +21,7 @@ from __future__ import print_function
import functools

# pylint: disable=g-bad-import-order
import tensorflow.compat.v2 as tf
# pylint: enable=g-bad-import-order
@@ -130,8 +131,8 @@ def create_ncf_input_data(params,
      from tf record files. Must be specified when params["train_input_dataset"]
      is specified.
    strategy: Distribution strategy used for distributed training. If specified,
      used to assert that evaluation batch size is correctly a multiple of total
      number of devices used.

  Returns:
    (training dataset, evaluation dataset, train steps per epoch,
...
@@ -26,6 +26,7 @@ import json
import os

# pylint: disable=g-bad-import-order
from absl import app
from absl import flags
from absl import logging
@@ -42,7 +43,6 @@ from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
from official.utils.misc import model_helpers

FLAGS = flags.FLAGS
@@ -50,9 +50,7 @@ def metric_fn(logits, dup_mask, match_mlperf):
  dup_mask = tf.cast(dup_mask, tf.float32)
  logits = tf.slice(logits, [0, 1], [-1, -1])
  in_top_k, _, metric_weights, _ = neumf_model.compute_top_k_and_ndcg(
      logits, dup_mask, match_mlperf)
  metric_weights = tf.cast(metric_weights, tf.float32)
  return in_top_k, metric_weights
@@ -152,9 +150,10 @@ class CustomEarlyStopping(tf.keras.callbacks.Callback):
    logs = logs or {}
    monitor_value = logs.get(self.monitor)
    if monitor_value is None:
      logging.warning(
          "Early stopping conditioned on metric `%s` "
          "which is not available. Available metrics are: %s", self.monitor,
          ",".join(list(logs.keys())))
    return monitor_value
@@ -181,12 +180,9 @@ def _get_keras_model(params):
  logits = base_model.output

  zeros = tf.keras.layers.Lambda(lambda x: x * 0)(logits)

  softmax_logits = tf.keras.layers.concatenate([zeros, logits], axis=-1)

  # Custom training loop calculates loss and metric as a part of
  # training/evaluation step function.
@@ -204,7 +200,8 @@ def _get_keras_model(params):
          movielens.ITEM_COLUMN: item_input,
          rconst.VALID_POINT_MASK: valid_pt_mask_input,
          rconst.DUPLICATE_MASK: dup_mask_input,
          rconst.TRAIN_LABEL_KEY: label_input
      },
      outputs=softmax_logits)

  keras_model.summary()
@@ -412,8 +409,7 @@ def run_ncf_custom_training(params,
        optimizer.apply_gradients(grads)
      return loss

    per_replica_losses = strategy.run(step_fn, args=(next(train_iterator),))
    mean_loss = strategy.reduce(
        tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
    return mean_loss
@@ -432,8 +428,7 @@ def run_ncf_custom_training(params,
      return hr_sum, hr_count

    per_replica_hr_sum, per_replica_hr_count = (
        strategy.run(step_fn, args=(next(eval_iterator),)))
    hr_sum = strategy.reduce(
        tf.distribute.ReduceOp.SUM, per_replica_hr_sum, axis=None)
    hr_count = strategy.reduce(
# Write train loss once in every 1000 steps. # Write train loss once in every 1000 steps.
if train_summary_writer and step % 1000 == 0: if train_summary_writer and step % 1000 == 0:
with train_summary_writer.as_default(): with train_summary_writer.as_default():
tf.summary.scalar("training_loss", train_loss/(step + 1), tf.summary.scalar(
step=current_step) "training_loss", train_loss / (step + 1), step=current_step)
for c in callbacks: for c in callbacks:
c.on_batch_end(current_step) c.on_batch_end(current_step)
@@ -552,7 +547,7 @@ def build_stats(loss, eval_result, time_callback):
    if len(timestamp_log) > 1:
      stats["avg_exp_per_second"] = (
          time_callback.batch_size * time_callback.log_steps *
          (len(time_callback.timestamp_log) - 1) /
          (timestamp_log[-1].timestamp - timestamp_log[0].timestamp))

  return stats
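Note: a worked example of that throughput formula with assumed values:

    # batch_size=1024, log_steps=100, and 11 timestamps spanning 50 s of wall
    # time: every interval between consecutive timestamps covers log_steps
    # batches, so 10 intervals * 100 steps * 1024 examples / 50 s.
    batch_size, log_steps, num_timestamps, span_s = 1024, 100, 11, 50.0
    avg_exp_per_second = batch_size * log_steps * (num_timestamps - 1) / span_s
    # -> 20480.0 examples/second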
...
@@ -48,64 +48,68 @@ class NcfTest(tf.test.TestCase):
  _BASE_END_TO_END_FLAGS = ['-batch_size', '1044', '-train_epochs', '1']

  @unittest.mock.patch.object(rconst, 'SYNTHETIC_BATCHES_PER_EPOCH', 100)
  def test_end_to_end_keras_no_dist_strat(self):
    integration.run_synthetic(
        ncf_keras_main.main,
        tmp_root=self.get_temp_dir(),
        extra_flags=self._BASE_END_TO_END_FLAGS +
        ['-distribution_strategy', 'off'])
  @unittest.mock.patch.object(rconst, 'SYNTHETIC_BATCHES_PER_EPOCH', 100)
  def test_end_to_end_keras_dist_strat(self):
    integration.run_synthetic(
        ncf_keras_main.main,
        tmp_root=self.get_temp_dir(),
        extra_flags=self._BASE_END_TO_END_FLAGS + ['-num_gpus', '0'])
  @unittest.mock.patch.object(rconst, 'SYNTHETIC_BATCHES_PER_EPOCH', 100)
  def test_end_to_end_keras_dist_strat_ctl(self):
    flags = (
        self._BASE_END_TO_END_FLAGS + ['-num_gpus', '0'] +
        ['-keras_use_ctl', 'True'])
    integration.run_synthetic(
        ncf_keras_main.main, tmp_root=self.get_temp_dir(), extra_flags=flags)
  @unittest.mock.patch.object(rconst, 'SYNTHETIC_BATCHES_PER_EPOCH', 100)
  def test_end_to_end_keras_1_gpu_dist_strat_fp16(self):
    if context.num_gpus() < 1:
      self.skipTest(
          '{} GPUs are not available for this test. {} GPUs are available'
          .format(1, context.num_gpus()))

    integration.run_synthetic(
        ncf_keras_main.main,
        tmp_root=self.get_temp_dir(),
        extra_flags=self._BASE_END_TO_END_FLAGS +
        ['-num_gpus', '1', '--dtype', 'fp16'])
  @unittest.mock.patch.object(rconst, 'SYNTHETIC_BATCHES_PER_EPOCH', 100)
  def test_end_to_end_keras_1_gpu_dist_strat_ctl_fp16(self):
    if context.num_gpus() < 1:
      self.skipTest(
          '{} GPUs are not available for this test. {} GPUs are available'
          .format(1, context.num_gpus()))

    integration.run_synthetic(
        ncf_keras_main.main,
        tmp_root=self.get_temp_dir(),
        extra_flags=self._BASE_END_TO_END_FLAGS +
        ['-num_gpus', '1', '--dtype', 'fp16', '--keras_use_ctl'])
  @unittest.mock.patch.object(rconst, 'SYNTHETIC_BATCHES_PER_EPOCH', 100)
  def test_end_to_end_keras_2_gpu_fp16(self):
    if context.num_gpus() < 2:
      self.skipTest(
          '{} GPUs are not available for this test. {} GPUs are available'
          .format(2, context.num_gpus()))

    integration.run_synthetic(
        ncf_keras_main.main,
        tmp_root=self.get_temp_dir(),
        extra_flags=self._BASE_END_TO_END_FLAGS +
        ['-num_gpus', '2', '--dtype', 'fp16'])
if __name__ == "__main__": if __name__ == '__main__':
tf.test.main() tf.test.main()
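Each test above patches rconst.SYNTHETIC_BATCHES_PER_EPOCH down to 100 so the synthetic end-to-end runs stay fast. A minimal sketch of the same patching pattern, using a hypothetical class and attribute:

    import unittest.mock

    class _Config:  # hypothetical stand-in for a constants module
      BATCHES_PER_EPOCH = 10 ** 6

    # As a decorator, patch.object swaps the attribute in for the duration
    # of the call and restores the original afterwards, even on failure.
    @unittest.mock.patch.object(_Config, "BATCHES_PER_EPOCH", 100)
    def _quick_run():
      assert _Config.BATCHES_PER_EPOCH == 100

    _quick_run()
    assert _Config.BATCHES_PER_EPOCH == 10 ** 6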
...@@ -111,8 +111,7 @@ def neumf_model_fn(features, labels, mode, params): ...@@ -111,8 +111,7 @@ def neumf_model_fn(features, labels, mode, params):
loss = tf.compat.v1.losses.sparse_softmax_cross_entropy( loss = tf.compat.v1.losses.sparse_softmax_cross_entropy(
labels=labels, labels=labels,
logits=softmax_logits, logits=softmax_logits,
weights=tf.cast(valid_pt_mask, tf.float32) weights=tf.cast(valid_pt_mask, tf.float32))
)
tf.identity(loss, name="cross_entropy") tf.identity(loss, name="cross_entropy")
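The weights argument zeroes out padded examples so they contribute nothing to the cross entropy. A minimal sketch of the behavior, with hypothetical logits and mask values:

    import tensorflow as tf

    # Four examples, two classes ([negative, positive]); the last example
    # is padding and gets weight 0.
    softmax_logits = tf.constant([[0.2, 0.8], [1.0, -1.0],
                                  [0.5, 0.5], [3.0, -3.0]])
    labels = tf.constant([1, 0, 1, 0])
    valid_pt_mask = tf.constant([True, True, True, False])

    loss = tf.compat.v1.losses.sparse_softmax_cross_entropy(
        labels=labels,
        logits=softmax_logits,
        weights=tf.cast(valid_pt_mask, tf.float32))
    # With the default reduction, the loss is averaged over the three
    # unmasked examples only.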
...@@ -196,15 +195,19 @@ def construct_model(user_input: tf.Tensor, item_input: tf.Tensor, ...@@ -196,15 +195,19 @@ def construct_model(user_input: tf.Tensor, item_input: tf.Tensor,
# GMF part # GMF part
mf_user_latent = tf.keras.layers.Lambda( mf_user_latent = tf.keras.layers.Lambda(
mf_slice_fn, name="embedding_user_mf")(embedding_user) mf_slice_fn, name="embedding_user_mf")(
embedding_user)
mf_item_latent = tf.keras.layers.Lambda( mf_item_latent = tf.keras.layers.Lambda(
mf_slice_fn, name="embedding_item_mf")(embedding_item) mf_slice_fn, name="embedding_item_mf")(
embedding_item)
# MLP part # MLP part
mlp_user_latent = tf.keras.layers.Lambda( mlp_user_latent = tf.keras.layers.Lambda(
mlp_slice_fn, name="embedding_user_mlp")(embedding_user) mlp_slice_fn, name="embedding_user_mlp")(
embedding_user)
mlp_item_latent = tf.keras.layers.Lambda( mlp_item_latent = tf.keras.layers.Lambda(
mlp_slice_fn, name="embedding_item_mlp")(embedding_item) mlp_slice_fn, name="embedding_item_mlp")(
embedding_item)
# Element-wise multiply # Element-wise multiply
mf_vector = tf.keras.layers.multiply([mf_user_latent, mf_item_latent]) mf_vector = tf.keras.layers.multiply([mf_user_latent, mf_item_latent])
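The Lambda layers split one shared embedding vector into its GMF and MLP halves (mf_slice_fn and mlp_slice_fn are defined earlier in the file). A rough sketch of the GMF branch, with assumed dimensions and an assumed slice function:

    import tensorflow as tf

    # Assumed layout: the first mf_dim columns feed the GMF branch, the
    # remaining mlp_dim columns feed the MLP branch.
    mf_dim, mlp_dim = 8, 16
    embedding_user = tf.keras.Input(shape=(mf_dim + mlp_dim,))
    embedding_item = tf.keras.Input(shape=(mf_dim + mlp_dim,))

    mf_slice_fn = lambda x: x[:, :mf_dim]
    mf_user_latent = tf.keras.layers.Lambda(mf_slice_fn)(embedding_user)
    mf_item_latent = tf.keras.layers.Lambda(mf_slice_fn)(embedding_item)

    # GMF scores come from the element-wise product of the two halves.
    mf_vector = tf.keras.layers.multiply([mf_user_latent, mf_item_latent])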
...@@ -225,8 +228,11 @@ def construct_model(user_input: tf.Tensor, item_input: tf.Tensor, ...@@ -225,8 +228,11 @@ def construct_model(user_input: tf.Tensor, item_input: tf.Tensor,
# Final prediction layer # Final prediction layer
logits = tf.keras.layers.Dense( logits = tf.keras.layers.Dense(
1, activation=None, kernel_initializer="lecun_uniform", 1,
name=movielens.RATING_COLUMN)(predict_vector) activation=None,
kernel_initializer="lecun_uniform",
name=movielens.RATING_COLUMN)(
predict_vector)
# Print model topology. # Print model topology.
model = tf.keras.models.Model([user_input, item_input], logits) model = tf.keras.models.Model([user_input, item_input], logits)
...@@ -263,8 +269,7 @@ def _get_estimator_spec_with_metrics(logits: tf.Tensor, ...@@ -263,8 +269,7 @@ def _get_estimator_spec_with_metrics(logits: tf.Tensor,
return tf.estimator.EstimatorSpec( return tf.estimator.EstimatorSpec(
mode=tf.estimator.ModeKeys.EVAL, mode=tf.estimator.ModeKeys.EVAL,
loss=cross_entropy, loss=cross_entropy,
eval_metric_ops=metric_fn(in_top_k, ndcg, metric_weights) eval_metric_ops=metric_fn(in_top_k, ndcg, metric_weights))
)
def compute_eval_loss_and_metrics_helper(logits: tf.Tensor, def compute_eval_loss_and_metrics_helper(logits: tf.Tensor,
...@@ -335,9 +340,13 @@ def compute_eval_loss_and_metrics_helper(logits: tf.Tensor, ...@@ -335,9 +340,13 @@ def compute_eval_loss_and_metrics_helper(logits: tf.Tensor,
# Examples are provided by the eval Dataset in a structured format, so eval # Examples are provided by the eval Dataset in a structured format, so eval
# labels can be reconstructed on the fly. # labels can be reconstructed on the fly.
eval_labels = tf.reshape(shape=(-1,), tensor=tf.one_hot( eval_labels = tf.reshape(
tf.zeros(shape=(logits_by_user.shape[0],), dtype=tf.int32) + shape=(-1,),
rconst.NUM_EVAL_NEGATIVES, logits_by_user.shape[1], dtype=tf.int32)) tensor=tf.one_hot(
tf.zeros(shape=(logits_by_user.shape[0],), dtype=tf.int32) +
rconst.NUM_EVAL_NEGATIVES,
logits_by_user.shape[1],
dtype=tf.int32))
eval_labels_float = tf.cast(eval_labels, tf.float32) eval_labels_float = tf.cast(eval_labels, tf.float32)
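Each eval row holds NUM_EVAL_NEGATIVES negatives followed by the single true item, so the labels are a flattened one-hot at that fixed column. A small worked example with assumed sizes:

    import tensorflow as tf

    NUM_EVAL_NEGATIVES = 4  # assumed value for illustration
    num_users = 3
    row_len = NUM_EVAL_NEGATIVES + 1

    eval_labels = tf.reshape(
        tf.one_hot(
            tf.zeros((num_users,), dtype=tf.int32) + NUM_EVAL_NEGATIVES,
            row_len,
            dtype=tf.int32),
        shape=(-1,))
    # -> [0 0 0 0 1  0 0 0 0 1  0 0 0 0 1]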
...@@ -346,13 +355,14 @@ def compute_eval_loss_and_metrics_helper(logits: tf.Tensor, ...@@ -346,13 +355,14 @@ def compute_eval_loss_and_metrics_helper(logits: tf.Tensor,
# weights for the negative examples we compute a loss which is consistent with # weights for the negative examples we compute a loss which is consistent with
# the training data. (And provides apples-to-apples comparison) # the training data. (And provides apples-to-apples comparison)
negative_scale_factor = num_training_neg / rconst.NUM_EVAL_NEGATIVES negative_scale_factor = num_training_neg / rconst.NUM_EVAL_NEGATIVES
example_weights = ( example_weights = ((eval_labels_float +
(eval_labels_float + (1 - eval_labels_float) * negative_scale_factor) * (1 - eval_labels_float) * negative_scale_factor) *
(1 + rconst.NUM_EVAL_NEGATIVES) / (1 + num_training_neg)) (1 + rconst.NUM_EVAL_NEGATIVES) / (1 + num_training_neg))
# Tile metric weights back to logit dimensions # Tile metric weights back to logit dimensions
expanded_metric_weights = tf.reshape(tf.tile( expanded_metric_weights = tf.reshape(
metric_weights[:, tf.newaxis], (1, rconst.NUM_EVAL_NEGATIVES + 1)), (-1,)) tf.tile(metric_weights[:, tf.newaxis],
(1, rconst.NUM_EVAL_NEGATIVES + 1)), (-1,))
# ignore padded examples # ignore padded examples
example_weights *= tf.cast(expanded_metric_weights, tf.float32) example_weights *= tf.cast(expanded_metric_weights, tf.float32)
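Concretely, each negative is up-weighted by num_training_neg / NUM_EVAL_NEGATIVES so the positive-to-negative mass matches training, and the row is then rescaled so its total weight equals the row length. A worked example with assumed counts:

    # Assumed counts for illustration.
    num_training_neg = 8
    NUM_EVAL_NEGATIVES = 4

    negative_scale_factor = num_training_neg / NUM_EVAL_NEGATIVES  # 2.0
    rescale = (1 + NUM_EVAL_NEGATIVES) / (1 + num_training_neg)    # 5/9

    positive_weight = 1.0 * rescale                    # ~0.556
    negative_weight = negative_scale_factor * rescale  # ~1.111

    # Total row weight: 0.556 + 4 * 1.111 = 5.0, matching the row length,
    # while the positive:negative mass ratio is 1:8 as in training.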
...@@ -362,12 +372,15 @@ def compute_eval_loss_and_metrics_helper(logits: tf.Tensor, ...@@ -362,12 +372,15 @@ def compute_eval_loss_and_metrics_helper(logits: tf.Tensor,
def metric_fn(top_k_tensor, ndcg_tensor, weight_tensor): def metric_fn(top_k_tensor, ndcg_tensor, weight_tensor):
return { return {
rconst.HR_KEY: tf.compat.v1.metrics.mean(top_k_tensor, rconst.HR_KEY:
weights=weight_tensor, tf.compat.v1.metrics.mean(
name=rconst.HR_METRIC_NAME), top_k_tensor, weights=weight_tensor,
rconst.NDCG_KEY: tf.compat.v1.metrics.mean(ndcg_tensor, name=rconst.HR_METRIC_NAME),
weights=weight_tensor, rconst.NDCG_KEY:
name=rconst.NDCG_METRIC_NAME) tf.compat.v1.metrics.mean(
ndcg_tensor,
weights=weight_tensor,
name=rconst.NDCG_METRIC_NAME)
} }
return cross_entropy, metric_fn, in_top_k, ndcg, metric_weights return cross_entropy, metric_fn, in_top_k, ndcg, metric_weights
...@@ -405,27 +418,26 @@ def compute_top_k_and_ndcg(logits: tf.Tensor, ...@@ -405,27 +418,26 @@ def compute_top_k_and_ndcg(logits: tf.Tensor,
# Determine the location of the first element in each row after the elements # Determine the location of the first element in each row after the elements
# are sorted. # are sorted.
sort_indices = tf.argsort( sort_indices = tf.argsort(logits_by_user, axis=1, direction="DESCENDING")
logits_by_user, axis=1, direction="DESCENDING")
# Use matrix multiplication to extract the position of the true item from the # Use matrix multiplication to extract the position of the true item from the
# tensor of sorted indices. This approach is chosen because both GPUs and TPUs # tensor of sorted indices. This approach is chosen because both GPUs and TPUs
# perform matrix multiplications very quickly. This is similar to np.argwhere. # perform matrix multiplications very quickly. This is similar to np.argwhere.
# However this is a special case because the target will only appear in # However this is a special case because the target will only appear in
# sort_indices once. # sort_indices once.
one_hot_position = tf.cast(tf.equal(sort_indices, rconst.NUM_EVAL_NEGATIVES), one_hot_position = tf.cast(
tf.int32) tf.equal(sort_indices, rconst.NUM_EVAL_NEGATIVES), tf.int32)
sparse_positions = tf.multiply( sparse_positions = tf.multiply(
one_hot_position, tf.range(logits_by_user.shape[1])[tf.newaxis, :]) one_hot_position,
tf.range(logits_by_user.shape[1])[tf.newaxis, :])
position_vector = tf.reduce_sum(sparse_positions, axis=1) position_vector = tf.reduce_sum(sparse_positions, axis=1)
in_top_k = tf.cast(tf.less(position_vector, rconst.TOP_K), tf.float32) in_top_k = tf.cast(tf.less(position_vector, rconst.TOP_K), tf.float32)
ndcg = tf.math.log(2.) / tf.math.log( ndcg = tf.math.log(2.) / tf.math.log(tf.cast(position_vector, tf.float32) + 2)
tf.cast(position_vector, tf.float32) + 2)
ndcg *= in_top_k ndcg *= in_top_k
# If a row is a padded row, all but the first element will be a duplicate. # If a row is a padded row, all but the first element will be a duplicate.
metric_weights = tf.not_equal(tf.reduce_sum(duplicate_mask_by_user, axis=1), metric_weights = tf.not_equal(
rconst.NUM_EVAL_NEGATIVES) tf.reduce_sum(duplicate_mask_by_user, axis=1), rconst.NUM_EVAL_NEGATIVES)
return in_top_k, ndcg, metric_weights, logits_by_user return in_top_k, ndcg, metric_weights, logits_by_user
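The one-hot comparison recovers the rank of the true item without a gather: equality against the positive's fixed column yields a one-hot row over sorted positions, and a weighted sum collapses it to a single index. A small sketch with assumed constants:

    import tensorflow as tf

    NUM_EVAL_NEGATIVES = 4  # assumed; the true item sits at this column
    TOP_K = 3               # assumed cutoff

    logits_by_user = tf.constant([[0.1, 0.9, 0.2, 0.3, 0.8],
                                  [0.5, 0.1, 0.2, 0.3, 0.9]])
    sort_indices = tf.argsort(logits_by_user, axis=1, direction="DESCENDING")
    one_hot_position = tf.cast(
        tf.equal(sort_indices, NUM_EVAL_NEGATIVES), tf.int32)
    position_vector = tf.reduce_sum(
        one_hot_position * tf.range(logits_by_user.shape[1])[tf.newaxis, :],
        axis=1)
    # position_vector -> [1, 0]: the true item ranks 2nd and 1st.
    in_top_k = tf.cast(tf.less(position_vector, TOP_K), tf.float32)
    ndcg = in_top_k * tf.math.log(2.) / tf.math.log(
        tf.cast(position_vector, tf.float32) + 2)
    # -> in_top_k [1., 1.], ndcg [~0.631, 1.].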
...@@ -37,9 +37,7 @@ def permutation(args): ...@@ -37,9 +37,7 @@ def permutation(args):
args: A size two tuple that will be unpacked into the size of the permutation args: A size two tuple that will be unpacked into the size of the permutation
and the random seed. This form is used because starmap is not universally and the random seed. This form is used because starmap is not universally
available. available.
returns: A NumPy array containing a random permutation.
returns:
A NumPy array containing a random permutation.
""" """
x, seed = args x, seed = args
...@@ -53,8 +51,11 @@ def permutation(args): ...@@ -53,8 +51,11 @@ def permutation(args):
def very_slightly_biased_randint(max_val_vector): def very_slightly_biased_randint(max_val_vector):
sample_dtype = np.uint64 sample_dtype = np.uint64
out_dtype = max_val_vector.dtype out_dtype = max_val_vector.dtype
samples = np.random.randint(low=0, high=np.iinfo(sample_dtype).max, samples = np.random.randint(
size=max_val_vector.shape, dtype=sample_dtype) low=0,
high=np.iinfo(sample_dtype).max,
size=max_val_vector.shape,
dtype=sample_dtype)
return np.mod(samples, max_val_vector.astype(sample_dtype)).astype(out_dtype) return np.mod(samples, max_val_vector.astype(sample_dtype)).astype(out_dtype)
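The bias comes from the modulo: residues below (2**64 - 1) % max_val occur once more often than the rest, a relative error around max_val / 2**64, which is negligible for item counts in the millions. A sketch of the sampling with small illustrative values:

    import numpy as np

    max_val_vector = np.array([5, 10, 100], dtype=np.int64)
    samples = np.random.randint(
        low=0,
        high=np.iinfo(np.uint64).max,
        size=max_val_vector.shape,
        dtype=np.uint64)
    out = np.mod(samples, max_val_vector.astype(np.uint64)).astype(np.int64)
    # Each entry is (almost exactly) uniform on [0, max_val); for
    # max_val ~ 1e6 the modulo bias is on the order of 1e-13.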
...@@ -88,5 +89,5 @@ def mask_duplicates(x, axis=1): # type: (np.ndarray, int) -> np.ndarray ...@@ -88,5 +89,5 @@ def mask_duplicates(x, axis=1): # type: (np.ndarray, int) -> np.ndarray
# Duplicate values will have a difference of zero. By definition the first # Duplicate values will have a difference of zero. By definition the first
# element is never a duplicate. # element is never a duplicate.
return np.where(diffs[np.arange(x.shape[0])[:, np.newaxis], return np.where(diffs[np.arange(x.shape[0])[:, np.newaxis], inv_x_sort_ind],
inv_x_sort_ind], 0, 1) 0, 1)
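The mechanics of mask_duplicates: sort each row, mark zero differences between sorted neighbors (duplicates), then scatter the marks back through the inverse sort permutation. A one-row sketch of the idea:

    import numpy as np

    x = np.array([[3, 1, 3, 2, 1]])
    x_sort_ind = np.argsort(x, axis=1)
    sorted_x = x[np.arange(x.shape[0])[:, np.newaxis], x_sort_ind]
    # Pad the diffs with a leading 1 so the first element is never flagged.
    diffs = np.concatenate(
        [np.ones((x.shape[0], 1), dtype=sorted_x.dtype),
         np.diff(sorted_x, axis=1)], axis=1)
    inv_x_sort_ind = np.argsort(x_sort_ind, axis=1)
    mask = np.where(
        diffs[np.arange(x.shape[0])[:, np.newaxis], inv_x_sort_ind], 0, 1)
    # mask -> [[0, 0, 1, 0, 1]]: repeat occurrences of 3 and 1 are flagged.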
...@@ -103,9 +103,9 @@ def minimize_using_explicit_allreduce(tape, ...@@ -103,9 +103,9 @@ def minimize_using_explicit_allreduce(tape,
pre_allreduce_callbacks: A list of callback functions that take gradient pre_allreduce_callbacks: A list of callback functions that take gradient
and model variable pairs as input, manipulate them, and return new and model variable pairs as input, manipulate them, and return new
gradient and model variable pairs. The callback functions will be gradient and model variable pairs. The callback functions will be
invoked in the list order and before gradients are allreduced. invoked in the list order and before gradients are allreduced. With
With mixed precision training, the pre_allreduce_callbacks will be mixed precision training, the pre_allreduce_callbacks will be applied on
applied on scaled_gradients. Default is no callbacks. scaled_gradients. Default is no callbacks.
post_allreduce_callbacks: A list of callback functions that take post_allreduce_callbacks: A list of callback functions that take
gradient and model variable pairs as input, manipulate them, and gradient and model variable pairs as input, manipulate them, and
return new gradient and model variable pairs. The callback return new gradient and model variable pairs. The callback
......
...@@ -23,10 +23,18 @@ import tensorflow as tf ...@@ -23,10 +23,18 @@ import tensorflow as tf
from official.utils.flags._conventions import help_wrap from official.utils.flags._conventions import help_wrap
def define_base(data_dir=True, model_dir=True, clean=False, train_epochs=False, def define_base(data_dir=True,
epochs_between_evals=False, stop_threshold=False, model_dir=True,
batch_size=True, num_gpu=False, hooks=False, export_dir=False, clean=False,
distribution_strategy=False, run_eagerly=False): train_epochs=False,
epochs_between_evals=False,
stop_threshold=False,
batch_size=True,
num_gpu=False,
hooks=False,
export_dir=False,
distribution_strategy=False,
run_eagerly=False):
"""Register base flags. """Register base flags.
Args: Args:
...@@ -35,8 +43,8 @@ def define_base(data_dir=True, model_dir=True, clean=False, train_epochs=False, ...@@ -35,8 +43,8 @@ def define_base(data_dir=True, model_dir=True, clean=False, train_epochs=False,
clean: Create a flag for removing the model_dir. clean: Create a flag for removing the model_dir.
train_epochs: Create a flag to specify the number of training epochs. train_epochs: Create a flag to specify the number of training epochs.
epochs_between_evals: Create a flag to specify the frequency of testing. epochs_between_evals: Create a flag to specify the frequency of testing.
stop_threshold: Create a flag to specify a threshold accuracy or other stop_threshold: Create a flag to specify a threshold accuracy or other eval
eval metric which should trigger the end of training. metric which should trigger the end of training.
batch_size: Create a flag to specify the batch size. batch_size: Create a flag to specify the batch size.
num_gpu: Create a flag to specify the number of GPUs used. num_gpu: Create a flag to specify the number of GPUs used.
hooks: Create a flag to specify hooks for logging. hooks: Create a flag to specify hooks for logging.
...@@ -44,6 +52,7 @@ def define_base(data_dir=True, model_dir=True, clean=False, train_epochs=False, ...@@ -44,6 +52,7 @@ def define_base(data_dir=True, model_dir=True, clean=False, train_epochs=False,
distribution_strategy: Create a flag to specify which Distribution Strategy distribution_strategy: Create a flag to specify which Distribution Strategy
to use. to use.
run_eagerly: Create a flag to specify whether to run eagerly, op by op. run_eagerly: Create a flag to specify whether to run eagerly, op by op.
Returns: Returns:
A list of flags for core.py to mark as key flags. A list of flags for core.py to mark as key flags.
""" """
...@@ -51,38 +60,48 @@ def define_base(data_dir=True, model_dir=True, clean=False, train_epochs=False, ...@@ -51,38 +60,48 @@ def define_base(data_dir=True, model_dir=True, clean=False, train_epochs=False,
if data_dir: if data_dir:
flags.DEFINE_string( flags.DEFINE_string(
name="data_dir", short_name="dd", default="/tmp", name="data_dir",
short_name="dd",
default="/tmp",
help=help_wrap("The location of the input data.")) help=help_wrap("The location of the input data."))
key_flags.append("data_dir") key_flags.append("data_dir")
if model_dir: if model_dir:
flags.DEFINE_string( flags.DEFINE_string(
name="model_dir", short_name="md", default="/tmp", name="model_dir",
short_name="md",
default="/tmp",
help=help_wrap("The location of the model checkpoint files.")) help=help_wrap("The location of the model checkpoint files."))
key_flags.append("model_dir") key_flags.append("model_dir")
if clean: if clean:
flags.DEFINE_boolean( flags.DEFINE_boolean(
name="clean", default=False, name="clean",
default=False,
help=help_wrap("If set, model_dir will be removed if it exists.")) help=help_wrap("If set, model_dir will be removed if it exists."))
key_flags.append("clean") key_flags.append("clean")
if train_epochs: if train_epochs:
flags.DEFINE_integer( flags.DEFINE_integer(
name="train_epochs", short_name="te", default=1, name="train_epochs",
short_name="te",
default=1,
help=help_wrap("The number of epochs used to train.")) help=help_wrap("The number of epochs used to train."))
key_flags.append("train_epochs") key_flags.append("train_epochs")
if epochs_between_evals: if epochs_between_evals:
flags.DEFINE_integer( flags.DEFINE_integer(
name="epochs_between_evals", short_name="ebe", default=1, name="epochs_between_evals",
short_name="ebe",
default=1,
help=help_wrap("The number of training epochs to run between " help=help_wrap("The number of training epochs to run between "
"evaluations.")) "evaluations."))
key_flags.append("epochs_between_evals") key_flags.append("epochs_between_evals")
if stop_threshold: if stop_threshold:
flags.DEFINE_float( flags.DEFINE_float(
name="stop_threshold", short_name="st", name="stop_threshold",
short_name="st",
default=None, default=None,
help=help_wrap("If passed, training will stop at the earlier of " help=help_wrap("If passed, training will stop at the earlier of "
"train_epochs and when the evaluation metric is " "train_epochs and when the evaluation metric is "
...@@ -90,7 +109,9 @@ def define_base(data_dir=True, model_dir=True, clean=False, train_epochs=False, ...@@ -90,7 +109,9 @@ def define_base(data_dir=True, model_dir=True, clean=False, train_epochs=False,
if batch_size: if batch_size:
flags.DEFINE_integer( flags.DEFINE_integer(
name="batch_size", short_name="bs", default=32, name="batch_size",
short_name="bs",
default=32,
help=help_wrap("Batch size for training and evaluation. When using " help=help_wrap("Batch size for training and evaluation. When using "
"multiple gpus, this is the global batch size for " "multiple gpus, this is the global batch size for "
"all devices. For example, if the batch size is 32 " "all devices. For example, if the batch size is 32 "
...@@ -100,49 +121,52 @@ def define_base(data_dir=True, model_dir=True, clean=False, train_epochs=False, ...@@ -100,49 +121,52 @@ def define_base(data_dir=True, model_dir=True, clean=False, train_epochs=False,
if num_gpu: if num_gpu:
flags.DEFINE_integer( flags.DEFINE_integer(
name="num_gpus", short_name="ng", name="num_gpus",
short_name="ng",
default=1, default=1,
help=help_wrap( help=help_wrap("How many GPUs to use at each worker with the "
"How many GPUs to use at each worker with the " "DistributionStrategies API. The default is 1."))
"DistributionStrategies API. The default is 1."))
if run_eagerly: if run_eagerly:
flags.DEFINE_boolean( flags.DEFINE_boolean(
name="run_eagerly", default=False, name="run_eagerly",
default=False,
help="Run the model op by op without building a model function.") help="Run the model op by op without building a model function.")
if hooks: if hooks:
flags.DEFINE_list( flags.DEFINE_list(
name="hooks", short_name="hk", default="LoggingTensorHook", name="hooks",
short_name="hk",
default="LoggingTensorHook",
help=help_wrap( help=help_wrap(
u"A list of (case insensitive) strings to specify the names of " u"A list of (case insensitive) strings to specify the names of "
u"training hooks. Example: `--hooks ProfilerHook," u"training hooks. Example: `--hooks ProfilerHook,"
u"ExamplesPerSecondHook`\n See hooks_helper " u"ExamplesPerSecondHook`\n See hooks_helper "
u"for details.") u"for details."))
)
key_flags.append("hooks") key_flags.append("hooks")
if export_dir: if export_dir:
flags.DEFINE_string( flags.DEFINE_string(
name="export_dir", short_name="ed", default=None, name="export_dir",
short_name="ed",
default=None,
help=help_wrap("If set, a SavedModel serialization of the model will " help=help_wrap("If set, a SavedModel serialization of the model will "
"be exported to this directory at the end of training. " "be exported to this directory at the end of training. "
"See the README for more details and relevant links.") "See the README for more details and relevant links."))
)
key_flags.append("export_dir") key_flags.append("export_dir")
if distribution_strategy: if distribution_strategy:
flags.DEFINE_string( flags.DEFINE_string(
name="distribution_strategy", short_name="ds", default="mirrored", name="distribution_strategy",
short_name="ds",
default="mirrored",
help=help_wrap("The Distribution Strategy to use for training. " help=help_wrap("The Distribution Strategy to use for training. "
"Accepted values are 'off', 'one_device', " "Accepted values are 'off', 'one_device', "
"'mirrored', 'parameter_server', 'collective', " "'mirrored', 'parameter_server', 'collective', "
"case insensitive. 'off' means not to use " "case insensitive. 'off' means not to use "
"Distribution Strategy; 'default' means to choose " "Distribution Strategy; 'default' means to choose "
"from `MirroredStrategy` or `OneDeviceStrategy` " "from `MirroredStrategy` or `OneDeviceStrategy` "
"according to the number of GPUs.") "according to the number of GPUs."))
)
return key_flags return key_flags
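Callers opt in to individual flags and register whatever define_base returns as key flags. A minimal usage sketch, assuming the module is imported through official.utils.flags.core:

    from absl import app, flags
    from official.utils.flags import core as flags_core

    def main(_):
      print("batch size:", flags.FLAGS.batch_size)
      print("strategy:", flags.FLAGS.distribution_strategy)

    if __name__ == "__main__":
      flags_core.define_base(num_gpu=True, distribution_strategy=True)
      app.run(main)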
......
...@@ -25,7 +25,8 @@ from official.utils.flags._conventions import help_wrap ...@@ -25,7 +25,8 @@ from official.utils.flags._conventions import help_wrap
def define_log_steps(): def define_log_steps():
flags.DEFINE_integer( flags.DEFINE_integer(
name="log_steps", default=100, name="log_steps",
default=100,
help="Frequency with which to log timing information with TimeHistory.") help="Frequency with which to log timing information with TimeHistory.")
return [] return []
...@@ -45,13 +46,16 @@ def define_benchmark(benchmark_log_dir=True, bigquery_uploader=True): ...@@ -45,13 +46,16 @@ def define_benchmark(benchmark_log_dir=True, bigquery_uploader=True):
key_flags = [] key_flags = []
flags.DEFINE_enum( flags.DEFINE_enum(
name="benchmark_logger_type", default="BaseBenchmarkLogger", name="benchmark_logger_type",
default="BaseBenchmarkLogger",
enum_values=["BaseBenchmarkLogger", "BenchmarkFileLogger"], enum_values=["BaseBenchmarkLogger", "BenchmarkFileLogger"],
help=help_wrap("The type of benchmark logger to use. Defaults to using " help=help_wrap("The type of benchmark logger to use. Defaults to using "
"BaseBenchmarkLogger which logs to STDOUT. Different " "BaseBenchmarkLogger which logs to STDOUT. Different "
"loggers will require other flags to be able to work.")) "loggers will require other flags to be able to work."))
flags.DEFINE_string( flags.DEFINE_string(
name="benchmark_test_id", short_name="bti", default=None, name="benchmark_test_id",
short_name="bti",
default=None,
help=help_wrap("The unique test ID of the benchmark run. It could be the " help=help_wrap("The unique test ID of the benchmark run. It could be the "
"combination of key parameters. It is hardware " "combination of key parameters. It is hardware "
"independent and could be used compare the performance " "independent and could be used compare the performance "
...@@ -63,34 +67,43 @@ def define_benchmark(benchmark_log_dir=True, bigquery_uploader=True): ...@@ -63,34 +67,43 @@ def define_benchmark(benchmark_log_dir=True, bigquery_uploader=True):
if benchmark_log_dir: if benchmark_log_dir:
flags.DEFINE_string( flags.DEFINE_string(
name="benchmark_log_dir", short_name="bld", default=None, name="benchmark_log_dir",
help=help_wrap("The location of the benchmark logging.") short_name="bld",
) default=None,
help=help_wrap("The location of the benchmark logging."))
if bigquery_uploader: if bigquery_uploader:
flags.DEFINE_string( flags.DEFINE_string(
name="gcp_project", short_name="gp", default=None, name="gcp_project",
short_name="gp",
default=None,
help=help_wrap( help=help_wrap(
"The GCP project name where the benchmark will be uploaded.")) "The GCP project name where the benchmark will be uploaded."))
flags.DEFINE_string( flags.DEFINE_string(
name="bigquery_data_set", short_name="bds", default="test_benchmark", name="bigquery_data_set",
short_name="bds",
default="test_benchmark",
help=help_wrap( help=help_wrap(
"The Bigquery dataset name where the benchmark will be uploaded.")) "The Bigquery dataset name where the benchmark will be uploaded."))
flags.DEFINE_string( flags.DEFINE_string(
name="bigquery_run_table", short_name="brt", default="benchmark_run", name="bigquery_run_table",
short_name="brt",
default="benchmark_run",
help=help_wrap("The Bigquery table name where the benchmark run " help=help_wrap("The Bigquery table name where the benchmark run "
"information will be uploaded.")) "information will be uploaded."))
flags.DEFINE_string( flags.DEFINE_string(
name="bigquery_run_status_table", short_name="brst", name="bigquery_run_status_table",
short_name="brst",
default="benchmark_run_status", default="benchmark_run_status",
help=help_wrap("The Bigquery table name where the benchmark run " help=help_wrap("The Bigquery table name where the benchmark run "
"status information will be uploaded.")) "status information will be uploaded."))
flags.DEFINE_string( flags.DEFINE_string(
name="bigquery_metric_table", short_name="bmt", name="bigquery_metric_table",
short_name="bmt",
default="benchmark_metric", default="benchmark_metric",
help=help_wrap("The Bigquery table name where the benchmark metric " help=help_wrap("The Bigquery table name where the benchmark metric "
"information will be uploaded.")) "information will be uploaded."))
...@@ -98,7 +111,7 @@ def define_benchmark(benchmark_log_dir=True, bigquery_uploader=True): ...@@ -98,7 +111,7 @@ def define_benchmark(benchmark_log_dir=True, bigquery_uploader=True):
@flags.multi_flags_validator( @flags.multi_flags_validator(
["benchmark_logger_type", "benchmark_log_dir"], ["benchmark_logger_type", "benchmark_log_dir"],
message="--benchmark_logger_type=BenchmarkFileLogger will require " message="--benchmark_logger_type=BenchmarkFileLogger will require "
"--benchmark_log_dir being set") "--benchmark_log_dir being set")
def _check_benchmark_log_dir(flags_dict): def _check_benchmark_log_dir(flags_dict):
benchmark_logger_type = flags_dict["benchmark_logger_type"] benchmark_logger_type = flags_dict["benchmark_logger_type"]
if benchmark_logger_type == "BenchmarkFileLogger": if benchmark_logger_type == "BenchmarkFileLogger":
......
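The validator above re-runs whenever either flag changes and aborts flag parsing with the given message if it returns False. A standalone sketch of the same pattern, with hypothetical flag names:

    from absl import flags

    flags.DEFINE_string("logger_type", "stdout", "Hypothetical flag.")
    flags.DEFINE_string("log_dir", None, "Hypothetical flag.")

    @flags.multi_flags_validator(
        ["logger_type", "log_dir"],
        message="--logger_type=file requires --log_dir to be set")
    def _check_log_dir(flags_dict):
      # Parsing fails with the message above whenever this returns False.
      if flags_dict["logger_type"] == "file":
        return flags_dict["log_dir"] is not None
      return True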
...@@ -25,13 +25,12 @@ import functools ...@@ -25,13 +25,12 @@ import functools
from absl import app as absl_app from absl import app as absl_app
from absl import flags from absl import flags
# This codifies help string conventions and makes it easy to update them if # This codifies help string conventions and makes it easy to update them if
# necessary. Currently the only major effect is that help bodies start on the # necessary. Currently the only major effect is that help bodies start on the
# line after flags are listed. All flag definitions should wrap the text bodies # line after flags are listed. All flag definitions should wrap the text bodies
# with help wrap when calling DEFINE_*. # with help wrap when calling DEFINE_*.
_help_wrap = functools.partial(flags.text_wrap, length=80, indent="", _help_wrap = functools.partial(
firstline_indent="\n") flags.text_wrap, length=80, indent="", firstline_indent="\n")
# Pretty formatting causes issues when utf-8 is not installed on a system. # Pretty formatting causes issues when utf-8 is not installed on a system.
...@@ -46,6 +45,7 @@ def _stdout_utf8(): ...@@ -46,6 +45,7 @@ def _stdout_utf8():
if _stdout_utf8(): if _stdout_utf8():
help_wrap = _help_wrap help_wrap = _help_wrap
else: else:
def help_wrap(text, *args, **kwargs): def help_wrap(text, *args, **kwargs):
return _help_wrap(text, *args, **kwargs).replace(u"\ufeff", u"") return _help_wrap(text, *args, **kwargs).replace(u"\ufeff", u"")
......
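help_wrap is flags.text_wrap with the conventions baked in: 80-column bodies whose first line starts on the line below the flag listing. A quick sketch of the effect:

    import functools

    from absl import flags

    help_wrap = functools.partial(
        flags.text_wrap, length=80, indent="", firstline_indent="\n")

    # The leading newline pushes the help body onto its own line under
    # the flag name in --help output.
    print(help_wrap("The location of the input data."))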
...@@ -26,11 +26,13 @@ from official.utils.flags._conventions import help_wrap ...@@ -26,11 +26,13 @@ from official.utils.flags._conventions import help_wrap
def require_cloud_storage(flag_names): def require_cloud_storage(flag_names):
"""Register a validator to check directory flags. """Register a validator to check directory flags.
Args: Args:
flag_names: An iterable of strings containing the names of flags to be flag_names: An iterable of strings containing the names of flags to be
checked. checked.
""" """
msg = "TPU requires GCS path for {}".format(", ".join(flag_names)) msg = "TPU requires GCS path for {}".format(", ".join(flag_names))
@flags.multi_flags_validator(["tpu"] + flag_names, message=msg) @flags.multi_flags_validator(["tpu"] + flag_names, message=msg)
def _path_check(flag_values): # pylint: disable=missing-docstring def _path_check(flag_values): # pylint: disable=missing-docstring
if flag_values["tpu"] is None: if flag_values["tpu"] is None:
...@@ -47,8 +49,10 @@ def require_cloud_storage(flag_names): ...@@ -47,8 +49,10 @@ def require_cloud_storage(flag_names):
def define_device(tpu=True): def define_device(tpu=True):
"""Register device specific flags. """Register device specific flags.
Args: Args:
tpu: Create flags to specify TPU operation. tpu: Create flags to specify TPU operation.
Returns: Returns:
A list of flags for core.py to mark as key flags. A list of flags for core.py to mark as key flags.
""" """
...@@ -57,7 +61,8 @@ def define_device(tpu=True): ...@@ -57,7 +61,8 @@ def define_device(tpu=True):
if tpu: if tpu:
flags.DEFINE_string( flags.DEFINE_string(
name="tpu", default=None, name="tpu",
default=None,
help=help_wrap( help=help_wrap(
"The Cloud TPU to use for training. This should be either the name " "The Cloud TPU to use for training. This should be either the name "
"used when creating the Cloud TPU, or a " "used when creating the Cloud TPU, or a "
...@@ -66,20 +71,24 @@ def define_device(tpu=True): ...@@ -66,20 +71,24 @@ def define_device(tpu=True):
key_flags.append("tpu") key_flags.append("tpu")
flags.DEFINE_string( flags.DEFINE_string(
name="tpu_zone", default=None, name="tpu_zone",
default=None,
help=help_wrap( help=help_wrap(
"[Optional] GCE zone where the Cloud TPU is located in. If not " "[Optional] GCE zone where the Cloud TPU is located in. If not "
"specified, we will attempt to automatically detect the GCE " "specified, we will attempt to automatically detect the GCE "
"project from metadata.")) "project from metadata."))
flags.DEFINE_string( flags.DEFINE_string(
name="tpu_gcp_project", default=None, name="tpu_gcp_project",
default=None,
help=help_wrap( help=help_wrap(
"[Optional] Project name for the Cloud TPU-enabled project. If not " "[Optional] Project name for the Cloud TPU-enabled project. If not "
"specified, we will attempt to automatically detect the GCE " "specified, we will attempt to automatically detect the GCE "
"project from metadata.")) "project from metadata."))
flags.DEFINE_integer(name="num_tpu_shards", default=8, flags.DEFINE_integer(
help=help_wrap("Number of shards (TPU chips).")) name="num_tpu_shards",
default=8,
help=help_wrap("Number of shards (TPU chips)."))
return key_flags return key_flags
...@@ -38,7 +38,8 @@ def define_distribution(worker_hosts=True, task_index=True): ...@@ -38,7 +38,8 @@ def define_distribution(worker_hosts=True, task_index=True):
if worker_hosts: if worker_hosts:
flags.DEFINE_string( flags.DEFINE_string(
name='worker_hosts', default=None, name='worker_hosts',
default=None,
help=help_wrap( help=help_wrap(
'Comma-separated list of worker ip:port pairs for running ' 'Comma-separated list of worker ip:port pairs for running '
'multi-worker models with DistributionStrategy. The user would ' 'multi-worker models with DistributionStrategy. The user would '
...@@ -47,7 +48,8 @@ def define_distribution(worker_hosts=True, task_index=True): ...@@ -47,7 +48,8 @@ def define_distribution(worker_hosts=True, task_index=True):
if task_index: if task_index:
flags.DEFINE_integer( flags.DEFINE_integer(
name='task_index', default=-1, name='task_index',
default=-1,
help=help_wrap('If multi-worker training, the task_index of this ' help=help_wrap('If multi-worker training, the task_index of this '
'worker.')) 'worker.'))
......
...@@ -37,7 +37,9 @@ def define_image(data_format=True): ...@@ -37,7 +37,9 @@ def define_image(data_format=True):
if data_format: if data_format:
flags.DEFINE_enum( flags.DEFINE_enum(
name="data_format", short_name="df", default=None, name="data_format",
short_name="df",
default=None,
enum_values=["channels_first", "channels_last"], enum_values=["channels_first", "channels_last"],
help=help_wrap( help=help_wrap(
"A flag to override the data format used in the model. " "A flag to override the data format used in the model. "
......
...@@ -20,12 +20,11 @@ from __future__ import print_function ...@@ -20,12 +20,11 @@ from __future__ import print_function
import multiprocessing import multiprocessing
from absl import flags # pylint: disable=g-bad-import-order from absl import flags # pylint: disable=g-bad-import-order
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.flags._conventions import help_wrap from official.utils.flags._conventions import help_wrap
# Map string to TensorFlow dtype # Map string to TensorFlow dtype
DTYPE_MAP = { DTYPE_MAP = {
"fp16": tf.float16, "fp16": tf.float16,
...@@ -55,15 +54,22 @@ def get_loss_scale(flags_obj, default_for_fp16): ...@@ -55,15 +54,22 @@ def get_loss_scale(flags_obj, default_for_fp16):
return default_for_fp16 return default_for_fp16
def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False, def define_performance(num_parallel_calls=False,
synthetic_data=False, max_train_steps=False, dtype=False, inter_op=False,
all_reduce_alg=False, num_packs=False, intra_op=False,
synthetic_data=False,
max_train_steps=False,
dtype=False,
all_reduce_alg=False,
num_packs=False,
tf_gpu_thread_mode=False, tf_gpu_thread_mode=False,
datasets_num_private_threads=False, datasets_num_private_threads=False,
datasets_num_parallel_batches=False, datasets_num_parallel_batches=False,
dynamic_loss_scale=False, fp16_implementation=False, dynamic_loss_scale=False,
fp16_implementation=False,
loss_scale=False, loss_scale=False,
tf_data_experimental_slack=False, enable_xla=False, tf_data_experimental_slack=False,
enable_xla=False,
training_dataset_cache=False): training_dataset_cache=False):
"""Register flags for specifying performance tuning arguments. """Register flags for specifying performance tuning arguments.
...@@ -72,8 +78,8 @@ def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False, ...@@ -72,8 +78,8 @@ def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False,
inter_op: Create a flag to allow specification of inter op threads. inter_op: Create a flag to allow specification of inter op threads.
intra_op: Create a flag to allow specification of intra op threads. intra_op: Create a flag to allow specification of intra op threads.
synthetic_data: Create a flag to allow the use of synthetic data. synthetic_data: Create a flag to allow the use of synthetic data.
max_train_steps: Create a flag to allow specification of the maximum number max_train_steps: Create a flag to allow specification of the maximum number of
of training steps. training steps.
dtype: Create flags for specifying dtype. dtype: Create flags for specifying dtype.
all_reduce_alg: If set, forces a specific algorithm for multi-gpu. all_reduce_alg: If set, forces a specific algorithm for multi-gpu.
num_packs: If set, provides the number of packs for MirroredStrategy's cross num_packs: If set, provides the number of packs for MirroredStrategy's cross
...@@ -81,7 +87,7 @@ def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False, ...@@ -81,7 +87,7 @@ def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False,
tf_gpu_thread_mode: gpu_private triggers use of a private thread pool. tf_gpu_thread_mode: gpu_private triggers use of a private thread pool.
datasets_num_private_threads: Number of private threads for datasets. datasets_num_private_threads: Number of private threads for datasets.
datasets_num_parallel_batches: Determines how many batches to process in datasets_num_parallel_batches: Determines how many batches to process in
parallel when using map and batch from tf.data. parallel when using map and batch from tf.data.
dynamic_loss_scale: Allow the "loss_scale" flag to take on the value dynamic_loss_scale: Allow the "loss_scale" flag to take on the value
"dynamic". Only valid if `dtype` is True. "dynamic". Only valid if `dtype` is True.
fp16_implementation: Create fp16_implementation flag. fp16_implementation: Create fp16_implementation flag.
...@@ -91,8 +97,8 @@ def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False, ...@@ -91,8 +97,8 @@ def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False,
`experimental_slack` option. `experimental_slack` option.
enable_xla: Determines if XLA (auto clustering) is turned on. enable_xla: Determines if XLA (auto clustering) is turned on.
training_dataset_cache: Whether to cache the training dataset on workers. training_dataset_cache: Whether to cache the training dataset on workers.
Typically used to improve training performance when training data is in Typically used to improve training performance when training data is in
remote storage and can fit into worker memory. remote storage and can fit into worker memory.
Returns: Returns:
A list of flags for core.py to mark as key flags. A list of flags for core.py to mark as key flags.
...@@ -101,7 +107,8 @@ def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False, ...@@ -101,7 +107,8 @@ def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False,
key_flags = [] key_flags = []
if num_parallel_calls: if num_parallel_calls:
flags.DEFINE_integer( flags.DEFINE_integer(
name="num_parallel_calls", short_name="npc", name="num_parallel_calls",
short_name="npc",
default=multiprocessing.cpu_count(), default=multiprocessing.cpu_count(),
help=help_wrap("The number of records that are processed in parallel " help=help_wrap("The number of records that are processed in parallel "
"during input processing. This can be optimized per " "during input processing. This can be optimized per "
...@@ -111,20 +118,25 @@ def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False, ...@@ -111,20 +118,25 @@ def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False,
if inter_op: if inter_op:
flags.DEFINE_integer( flags.DEFINE_integer(
name="inter_op_parallelism_threads", short_name="inter", default=0, name="inter_op_parallelism_threads",
short_name="inter",
default=0,
help=help_wrap("Number of inter_op_parallelism_threads to use for CPU. " help=help_wrap("Number of inter_op_parallelism_threads to use for CPU. "
"See TensorFlow config.proto for details.") "See TensorFlow config.proto for details."))
)
if intra_op: if intra_op:
flags.DEFINE_integer( flags.DEFINE_integer(
name="intra_op_parallelism_threads", short_name="intra", default=0, name="intra_op_parallelism_threads",
short_name="intra",
default=0,
help=help_wrap("Number of intra_op_parallelism_threads to use for CPU. " help=help_wrap("Number of intra_op_parallelism_threads to use for CPU. "
"See TensorFlow config.proto for details.")) "See TensorFlow config.proto for details."))
if synthetic_data: if synthetic_data:
flags.DEFINE_bool( flags.DEFINE_bool(
name="use_synthetic_data", short_name="synth", default=False, name="use_synthetic_data",
short_name="synth",
default=False,
help=help_wrap( help=help_wrap(
"If set, use fake data (zeroes) instead of a real dataset. " "If set, use fake data (zeroes) instead of a real dataset. "
"This mode is useful for performance debugging, as it removes " "This mode is useful for performance debugging, as it removes "
...@@ -132,16 +144,20 @@ def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False, ...@@ -132,16 +144,20 @@ def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False,
if max_train_steps: if max_train_steps:
flags.DEFINE_integer( flags.DEFINE_integer(
name="max_train_steps", short_name="mts", default=None, help=help_wrap( name="max_train_steps",
short_name="mts",
default=None,
help=help_wrap(
"The model will stop training if the global_step reaches this " "The model will stop training if the global_step reaches this "
"value. If not set, training will run until the specified number " "value. If not set, training will run until the specified number "
"of epochs have run as usual. It is generally recommended to set " "of epochs have run as usual. It is generally recommended to set "
"--train_epochs=1 when using this flag." "--train_epochs=1 when using this flag."))
))
if dtype: if dtype:
flags.DEFINE_enum( flags.DEFINE_enum(
name="dtype", short_name="dt", default="fp32", name="dtype",
short_name="dt",
default="fp32",
enum_values=DTYPE_MAP.keys(), enum_values=DTYPE_MAP.keys(),
help=help_wrap("The TensorFlow datatype used for calculations. " help=help_wrap("The TensorFlow datatype used for calculations. "
"Variables may be cast to a higher precision on a " "Variables may be cast to a higher precision on a "
...@@ -155,8 +171,7 @@ def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False, ...@@ -155,8 +171,7 @@ def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False,
"variables. This is mathematically equivalent to training without " "variables. This is mathematically equivalent to training without "
"a loss scale, but the loss scale helps avoid some intermediate " "a loss scale, but the loss scale helps avoid some intermediate "
"gradients from underflowing to zero. If not provided the default " "gradients from underflowing to zero. If not provided the default "
"for fp16 is 128 and 1 for all other dtypes.{}" "for fp16 is 128 and 1 for all other dtypes.{}")
)
if dynamic_loss_scale: if dynamic_loss_scale:
loss_scale_help_text = loss_scale_help_text.format( loss_scale_help_text = loss_scale_help_text.format(
"This can be an int/float or the string 'dynamic'", "This can be an int/float or the string 'dynamic'",
...@@ -171,11 +186,13 @@ def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False, ...@@ -171,11 +186,13 @@ def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False,
loss_scale_validation_msg = "loss_scale should be a positive int/float." loss_scale_validation_msg = "loss_scale should be a positive int/float."
if loss_scale: if loss_scale:
flags.DEFINE_string( flags.DEFINE_string(
name="loss_scale", short_name="ls", default=None, name="loss_scale",
short_name="ls",
default=None,
help=help_wrap(loss_scale_help_text)) help=help_wrap(loss_scale_help_text))
@flags.validator(flag_name="loss_scale", @flags.validator(
message=loss_scale_validation_msg) flag_name="loss_scale", message=loss_scale_validation_msg)
def _check_loss_scale(loss_scale): # pylint: disable=unused-variable def _check_loss_scale(loss_scale): # pylint: disable=unused-variable
"""Validator to check the loss scale flag is valid.""" """Validator to check the loss scale flag is valid."""
if loss_scale is None: if loss_scale is None:
...@@ -193,7 +210,8 @@ def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False, ...@@ -193,7 +210,8 @@ def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False,
if fp16_implementation: if fp16_implementation:
flags.DEFINE_enum( flags.DEFINE_enum(
name="fp16_implementation", default="keras", name="fp16_implementation",
default="keras",
enum_values=("keras', 'graph_rewrite"), enum_values=("keras', 'graph_rewrite"),
help=help_wrap( help=help_wrap(
"When --dtype=fp16, how fp16 should be implemented. This has no " "When --dtype=fp16, how fp16 should be implemented. This has no "
...@@ -202,8 +220,8 @@ def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False, ...@@ -202,8 +220,8 @@ def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False,
"tf.train.experimental.enable_mixed_precision_graph_rewrite " "tf.train.experimental.enable_mixed_precision_graph_rewrite "
"API.")) "API."))
@flags.multi_flags_validator(["fp16_implementation", "dtype", @flags.multi_flags_validator(
"loss_scale"]) ["fp16_implementation", "dtype", "loss_scale"])
def _check_fp16_implementation(flags_dict): def _check_fp16_implementation(flags_dict):
"""Validator to check fp16_implementation flag is valid.""" """Validator to check fp16_implementation flag is valid."""
if (flags_dict["fp16_implementation"] == "graph_rewrite" and if (flags_dict["fp16_implementation"] == "graph_rewrite" and
...@@ -214,7 +232,9 @@ def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False, ...@@ -214,7 +232,9 @@ def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False,
if all_reduce_alg: if all_reduce_alg:
flags.DEFINE_string( flags.DEFINE_string(
name="all_reduce_alg", short_name="ara", default=None, name="all_reduce_alg",
short_name="ara",
default=None,
help=help_wrap("Defines the algorithm to use for performing all-reduce." help=help_wrap("Defines the algorithm to use for performing all-reduce."
"When specified with MirroredStrategy for single " "When specified with MirroredStrategy for single "
"worker, this controls " "worker, this controls "
...@@ -226,24 +246,26 @@ def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False, ...@@ -226,24 +246,26 @@ def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False,
if num_packs: if num_packs:
flags.DEFINE_integer( flags.DEFINE_integer(
name="num_packs", default=1, name="num_packs",
default=1,
help=help_wrap("Sets `num_packs` in the cross device ops used in " help=help_wrap("Sets `num_packs` in the cross device ops used in "
"MirroredStrategy. For details, see " "MirroredStrategy. For details, see "
"tf.distribute.NcclAllReduce.")) "tf.distribute.NcclAllReduce."))
if tf_gpu_thread_mode: if tf_gpu_thread_mode:
flags.DEFINE_string( flags.DEFINE_string(
name="tf_gpu_thread_mode", short_name="gt_mode", default=None, name="tf_gpu_thread_mode",
short_name="gt_mode",
default=None,
help=help_wrap( help=help_wrap(
"Whether and how the GPU device uses its own threadpool.") "Whether and how the GPU device uses its own threadpool."))
)
flags.DEFINE_integer( flags.DEFINE_integer(
name="per_gpu_thread_count", short_name="pgtc", default=0, name="per_gpu_thread_count",
help=help_wrap( short_name="pgtc",
"The number of threads to use for GPU. Only valid when " default=0,
"tf_gpu_thread_mode is not global.") help=help_wrap("The number of threads to use for GPU. Only valid when "
) "tf_gpu_thread_mode is not global."))
if datasets_num_private_threads: if datasets_num_private_threads:
flags.DEFINE_integer( flags.DEFINE_integer(
...@@ -251,8 +273,7 @@ def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False, ...@@ -251,8 +273,7 @@ def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False,
default=None, default=None,
help=help_wrap( help=help_wrap(
"Number of threads for a private threadpool created for all" "Number of threads for a private threadpool created for all"
"datasets computation..") "datasets computation.."))
)
if datasets_num_parallel_batches: if datasets_num_parallel_batches:
flags.DEFINE_integer( flags.DEFINE_integer(
...@@ -260,8 +281,7 @@ def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False, ...@@ -260,8 +281,7 @@ def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False,
default=None, default=None,
help=help_wrap( help=help_wrap(
"Determines how many batches to process in parallel when using " "Determines how many batches to process in parallel when using "
"map and batch from tf.data.") "map and batch from tf.data."))
)
if training_dataset_cache: if training_dataset_cache:
flags.DEFINE_boolean( flags.DEFINE_boolean(
...@@ -270,20 +290,19 @@ def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False, ...@@ -270,20 +290,19 @@ def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False,
help=help_wrap( help=help_wrap(
"Determines whether to cache the training dataset on workers. " "Determines whether to cache the training dataset on workers. "
"Typically used to improve training performance when training " "Typically used to improve training performance when training "
"data is in remote storage and can fit into worker memory.") "data is in remote storage and can fit into worker memory."))
)
if tf_data_experimental_slack: if tf_data_experimental_slack:
flags.DEFINE_boolean( flags.DEFINE_boolean(
name="tf_data_experimental_slack", name="tf_data_experimental_slack",
default=False, default=False,
help=help_wrap( help=help_wrap(
"Whether to enable tf.data's `experimental_slack` option.") "Whether to enable tf.data's `experimental_slack` option."))
)
if enable_xla: if enable_xla:
flags.DEFINE_boolean( flags.DEFINE_boolean(
name="enable_xla", default=False, name="enable_xla",
default=False,
help="Whether to enable XLA auto jit compilation") help="Whether to enable XLA auto jit compilation")
return key_flags return key_flags
...@@ -22,6 +22,7 @@ from __future__ import division ...@@ -22,6 +22,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import sys import sys
from six.moves import shlex_quote from six.moves import shlex_quote
from absl import app as absl_app from absl import app as absl_app
...@@ -65,6 +66,7 @@ def register_key_flags_in_core(f): ...@@ -65,6 +66,7 @@ def register_key_flags_in_core(f):
def core_fn(*args, **kwargs): def core_fn(*args, **kwargs):
key_flags = f(*args, **kwargs) key_flags = f(*args, **kwargs)
[flags.declare_key_flag(fl) for fl in key_flags] # pylint: disable=expression-not-assigned [flags.declare_key_flag(fl) for fl in key_flags] # pylint: disable=expression-not-assigned
return core_fn return core_fn
...@@ -80,16 +82,15 @@ define_performance = register_key_flags_in_core(_performance.define_performance) ...@@ -80,16 +82,15 @@ define_performance = register_key_flags_in_core(_performance.define_performance)
define_distribution = register_key_flags_in_core( define_distribution = register_key_flags_in_core(
_distribution.define_distribution) _distribution.define_distribution)
help_wrap = _conventions.help_wrap help_wrap = _conventions.help_wrap
get_num_gpus = _base.get_num_gpus get_num_gpus = _base.get_num_gpus
get_tf_dtype = _performance.get_tf_dtype get_tf_dtype = _performance.get_tf_dtype
get_loss_scale = _performance.get_loss_scale get_loss_scale = _performance.get_loss_scale
DTYPE_MAP = _performance.DTYPE_MAP DTYPE_MAP = _performance.DTYPE_MAP
require_cloud_storage = _device.require_cloud_storage require_cloud_storage = _device.require_cloud_storage
def _get_nondefault_flags_as_dict(): def _get_nondefault_flags_as_dict():
"""Returns the nondefault flags as a dict from flag name to value.""" """Returns the nondefault flags as a dict from flag name to value."""
nondefault_flags = {} nondefault_flags = {}
......
...@@ -22,12 +22,20 @@ from official.utils.flags import core as flags_core # pylint: disable=g-bad-imp ...@@ -22,12 +22,20 @@ from official.utils.flags import core as flags_core # pylint: disable=g-bad-imp
def define_flags(): def define_flags():
flags_core.define_base(clean=True, num_gpu=False, stop_threshold=True, flags_core.define_base(
hooks=True, train_epochs=True, clean=True,
epochs_between_evals=True) num_gpu=False,
stop_threshold=True,
hooks=True,
train_epochs=True,
epochs_between_evals=True)
flags_core.define_performance( flags_core.define_performance(
num_parallel_calls=True, inter_op=True, intra_op=True, num_parallel_calls=True,
dynamic_loss_scale=True, loss_scale=True, synthetic_data=True, inter_op=True,
intra_op=True,
dynamic_loss_scale=True,
loss_scale=True,
synthetic_data=True,
dtype=True) dtype=True)
flags_core.define_image() flags_core.define_image()
flags_core.define_benchmark() flags_core.define_benchmark()
...@@ -41,8 +49,7 @@ class BaseTester(unittest.TestCase): ...@@ -41,8 +49,7 @@ class BaseTester(unittest.TestCase):
define_flags() define_flags()
def test_default_setting(self): def test_default_setting(self):
"""Test to ensure fields exist and defaults can be set. """Test to ensure fields exist and defaults can be set."""
"""
defaults = dict( defaults = dict(
data_dir="dfgasf", data_dir="dfgasf",
...@@ -54,8 +61,7 @@ class BaseTester(unittest.TestCase): ...@@ -54,8 +61,7 @@ class BaseTester(unittest.TestCase):
num_parallel_calls=18, num_parallel_calls=18,
inter_op_parallelism_threads=5, inter_op_parallelism_threads=5,
intra_op_parallelism_threads=10, intra_op_parallelism_threads=10,
data_format="channels_first" data_format="channels_first")
)
flags_core.set_defaults(**defaults) flags_core.set_defaults(**defaults)
flags_core.parse_flags() flags_core.parse_flags()
...@@ -77,8 +83,7 @@ class BaseTester(unittest.TestCase): ...@@ -77,8 +83,7 @@ class BaseTester(unittest.TestCase):
assert flags.FLAGS.get_flag_value(name=key, default=None) == value assert flags.FLAGS.get_flag_value(name=key, default=None) == value
def test_booleans(self): def test_booleans(self):
"""Test to ensure boolean flags trigger as expected. """Test to ensure boolean flags trigger as expected."""
"""
flags_core.parse_flags([__file__, "--use_synthetic_data"]) flags_core.parse_flags([__file__, "--use_synthetic_data"])
...@@ -87,35 +92,33 @@ class BaseTester(unittest.TestCase): ...@@ -87,35 +92,33 @@ class BaseTester(unittest.TestCase):
def test_parse_dtype_info(self): def test_parse_dtype_info(self):
flags_core.parse_flags([__file__, "--dtype", "fp16"]) flags_core.parse_flags([__file__, "--dtype", "fp16"])
self.assertEqual(flags_core.get_tf_dtype(flags.FLAGS), tf.float16) self.assertEqual(flags_core.get_tf_dtype(flags.FLAGS), tf.float16)
self.assertEqual(flags_core.get_loss_scale(flags.FLAGS, self.assertEqual(
default_for_fp16=2), 2) flags_core.get_loss_scale(flags.FLAGS, default_for_fp16=2), 2)
flags_core.parse_flags( flags_core.parse_flags([__file__, "--dtype", "fp16", "--loss_scale", "5"])
[__file__, "--dtype", "fp16", "--loss_scale", "5"]) self.assertEqual(
self.assertEqual(flags_core.get_loss_scale(flags.FLAGS, flags_core.get_loss_scale(flags.FLAGS, default_for_fp16=2), 5)
default_for_fp16=2), 5)
flags_core.parse_flags( flags_core.parse_flags(
[__file__, "--dtype", "fp16", "--loss_scale", "dynamic"]) [__file__, "--dtype", "fp16", "--loss_scale", "dynamic"])
self.assertEqual(flags_core.get_loss_scale(flags.FLAGS, self.assertEqual(
default_for_fp16=2), "dynamic") flags_core.get_loss_scale(flags.FLAGS, default_for_fp16=2), "dynamic")
flags_core.parse_flags([__file__, "--dtype", "fp32"]) flags_core.parse_flags([__file__, "--dtype", "fp32"])
self.assertEqual(flags_core.get_tf_dtype(flags.FLAGS), tf.float32) self.assertEqual(flags_core.get_tf_dtype(flags.FLAGS), tf.float32)
self.assertEqual(flags_core.get_loss_scale(flags.FLAGS, self.assertEqual(
default_for_fp16=2), 1) flags_core.get_loss_scale(flags.FLAGS, default_for_fp16=2), 1)
flags_core.parse_flags([__file__, "--dtype", "fp32", "--loss_scale", "5"]) flags_core.parse_flags([__file__, "--dtype", "fp32", "--loss_scale", "5"])
self.assertEqual(flags_core.get_loss_scale(flags.FLAGS, self.assertEqual(
default_for_fp16=2), 5) flags_core.get_loss_scale(flags.FLAGS, default_for_fp16=2), 5)
with self.assertRaises(SystemExit): with self.assertRaises(SystemExit):
flags_core.parse_flags([__file__, "--dtype", "int8"]) flags_core.parse_flags([__file__, "--dtype", "int8"])
with self.assertRaises(SystemExit): with self.assertRaises(SystemExit):
flags_core.parse_flags([__file__, "--dtype", "fp16", flags_core.parse_flags(
"--loss_scale", "abc"]) [__file__, "--dtype", "fp16", "--loss_scale", "abc"])
def test_get_nondefault_flags_as_str(self): def test_get_nondefault_flags_as_str(self):
defaults = dict( defaults = dict(
...@@ -123,8 +126,7 @@ class BaseTester(unittest.TestCase): ...@@ -123,8 +126,7 @@ class BaseTester(unittest.TestCase):
data_dir="abc", data_dir="abc",
hooks=["LoggingTensorHook"], hooks=["LoggingTensorHook"],
stop_threshold=1.5, stop_threshold=1.5,
use_synthetic_data=False use_synthetic_data=False)
)
flags_core.set_defaults(**defaults) flags_core.set_defaults(**defaults)
flags_core.parse_flags() flags_core.parse_flags()
......