Commit 88253ce5 authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 326286926
parent 52371ffe
......@@ -41,7 +41,6 @@ from official.recommendation import popen_helper
from official.recommendation import stat_utils
from tensorflow.python.tpu.datasets import StreamingFilesDataset
SUMMARY_TEMPLATE = """General:
{spacer}Num users: {num_users}
{spacer}Num items: {num_items}
......@@ -74,25 +73,27 @@ class DatasetManager(object):
num_train_epochs=None):
# type: (bool, bool, int, typing.Optional[str], bool, int) -> None
"""Constructs a `DatasetManager` instance.
Args:
is_training: Boolean of whether the data provided is training or
evaluation data. This determines whether to reuse the data
(if is_training=False) and the exact structure to use when storing and
evaluation data. This determines whether to reuse the data (if
is_training=False) and the exact structure to use when storing and
yielding data.
stream_files: Boolean indicating whether data should be serialized and
written to file shards.
batches_per_epoch: The number of batches in a single epoch.
shard_root: The base directory to be used when stream_files=True.
deterministic: Forgo non-deterministic speedups. (i.e. sloppy=True)
num_train_epochs: Number of epochs to generate. If None, then each
call to `get_dataset()` increments the number of epochs requested.
num_train_epochs: Number of epochs to generate. If None, then each call to
`get_dataset()` increments the number of epochs requested.
"""
self._is_training = is_training
self._deterministic = deterministic
self._stream_files = stream_files
self._writers = []
self._write_locks = [threading.RLock() for _ in
range(rconst.NUM_FILE_SHARDS)] if stream_files else []
self._write_locks = [
threading.RLock() for _ in range(rconst.NUM_FILE_SHARDS)
] if stream_files else []
self._batches_per_epoch = batches_per_epoch
self._epochs_completed = 0
self._epochs_requested = num_train_epochs if num_train_epochs else 0
......@@ -103,8 +104,9 @@ class DatasetManager(object):
@property
def current_data_root(self):
subdir = (rconst.TRAIN_FOLDER_TEMPLATE.format(self._epochs_completed)
if self._is_training else rconst.EVAL_FOLDER)
subdir = (
rconst.TRAIN_FOLDER_TEMPLATE.format(self._epochs_completed)
if self._is_training else rconst.EVAL_FOLDER)
return os.path.join(self._shard_root, subdir)
def buffer_reached(self):
......@@ -123,8 +125,8 @@ class DatasetManager(object):
k: create_int_feature(v.astype(np.int64)) for k, v in data.items()
}
return tf.train.Example(
features=tf.train.Features(feature=feature_dict)).SerializeToString()
return tf.train.Example(features=tf.train.Features(
feature=feature_dict)).SerializeToString()
@staticmethod
def deserialize(serialized_data, batch_size=None, is_training=True):
......@@ -134,8 +136,8 @@ class DatasetManager(object):
serialized_data: A tensor containing serialized records.
batch_size: The data arrives pre-batched, so batch size is needed to
deserialize the data.
is_training: Boolean, whether the data to deserialize is training
data or evaluation data.
is_training: Boolean, whether the data to deserialize is training data
or evaluation data.
"""
def _get_feature_map(batch_size, is_training=True):
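For orientation, a minimal round-trip sketch of the serialization path above (not part of this change; `create_int_feature` is an assumed helper wrapping an int64 vector, as its name suggests):

import numpy as np
import tensorflow as tf

def create_int_feature(values):
  # Assumed helper: wrap an int64 vector as a tf.train.Feature.
  return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))

batch = {"user_id": np.array([1, 2]), "item_id": np.array([3, 4])}
feature_dict = {
    k: create_int_feature(v.astype(np.int64)) for k, v in batch.items()
}
record = tf.train.Example(
    features=tf.train.Features(feature=feature_dict)).SerializeToString()
# `record` is what each TFRecord shard receives; `deserialize` later parses
# a whole pre-batched Example back into dense tensors.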
......@@ -171,13 +173,16 @@ class DatasetManager(object):
valid_point_mask = tf.cast(features[rconst.VALID_POINT_MASK], tf.bool)
fake_dup_mask = tf.zeros_like(users)
return {
movielens.USER_COLUMN: users,
movielens.ITEM_COLUMN: items,
rconst.VALID_POINT_MASK: valid_point_mask,
movielens.USER_COLUMN:
users,
movielens.ITEM_COLUMN:
items,
rconst.VALID_POINT_MASK:
valid_point_mask,
rconst.TRAIN_LABEL_KEY:
tf.reshape(tf.cast(features["labels"], tf.bool),
(batch_size, 1)),
rconst.DUPLICATE_MASK: fake_dup_mask
tf.reshape(tf.cast(features["labels"], tf.bool), (batch_size, 1)),
rconst.DUPLICATE_MASK:
fake_dup_mask
}
else:
labels = tf.cast(tf.zeros_like(users), tf.bool)
......@@ -228,8 +233,10 @@ class DatasetManager(object):
if self._stream_files:
tf.io.gfile.makedirs(self.current_data_root)
template = os.path.join(self.current_data_root, rconst.SHARD_TEMPLATE)
self._writers = [tf.io.TFRecordWriter(template.format(i))
for i in range(rconst.NUM_FILE_SHARDS)]
self._writers = [
tf.io.TFRecordWriter(template.format(i))
for i in range(rconst.NUM_FILE_SHARDS)
]
def end_construction(self):
if self._stream_files:
......@@ -273,8 +280,8 @@ class DatasetManager(object):
Args:
batch_size: The per-replica batch size of the dataset.
epochs_between_evals: How many epochs worth of data to yield.
(Generator mode only.)
epochs_between_evals: How many epochs worth of data to yield. (Generator
mode only.)
"""
self.increment_request_epoch()
if self._stream_files:
......@@ -285,11 +292,13 @@ class DatasetManager(object):
if not self._is_training:
self._result_queue.put(epoch_data_dir) # Eval data is reused.
file_pattern = os.path.join(
epoch_data_dir, rconst.SHARD_TEMPLATE.format("*"))
file_pattern = os.path.join(epoch_data_dir,
rconst.SHARD_TEMPLATE.format("*"))
dataset = StreamingFilesDataset(
files=file_pattern, worker_job=popen_helper.worker_job(),
num_parallel_reads=rconst.NUM_FILE_SHARDS, num_epochs=1,
files=file_pattern,
worker_job=popen_helper.worker_job(),
num_parallel_reads=rconst.NUM_FILE_SHARDS,
num_epochs=1,
sloppy=not self._deterministic)
map_fn = functools.partial(
self.deserialize,
......@@ -298,8 +307,10 @@ class DatasetManager(object):
dataset = dataset.map(map_fn, num_parallel_calls=16)
else:
types = {movielens.USER_COLUMN: rconst.USER_DTYPE,
movielens.ITEM_COLUMN: rconst.ITEM_DTYPE}
types = {
movielens.USER_COLUMN: rconst.USER_DTYPE,
movielens.ITEM_COLUMN: rconst.ITEM_DTYPE
}
shapes = {
movielens.USER_COLUMN: tf.TensorShape([batch_size, 1]),
movielens.ITEM_COLUMN: tf.TensorShape([batch_size, 1])
......@@ -319,8 +330,7 @@ class DatasetManager(object):
data_generator = functools.partial(
self.data_generator, epochs_between_evals=epochs_between_evals)
dataset = tf.data.Dataset.from_generator(
generator=data_generator, output_types=types,
output_shapes=shapes)
generator=data_generator, output_types=types, output_shapes=shapes)
return dataset.prefetch(16)
......@@ -332,16 +342,17 @@ class DatasetManager(object):
# Estimator passes batch_size during training and eval_batch_size during
# eval.
param_batch_size = (params["batch_size"] if self._is_training else
params.get("eval_batch_size") or params["batch_size"])
param_batch_size = (
params["batch_size"] if self._is_training else
params.get("eval_batch_size") or params["batch_size"])
if batch_size != param_batch_size:
raise ValueError("producer batch size ({}) differs from params batch "
"size ({})".format(batch_size, param_batch_size))
epochs_between_evals = (params.get("epochs_between_evals", 1)
if self._is_training else 1)
return self.get_dataset(batch_size=batch_size,
epochs_between_evals=epochs_between_evals)
epochs_between_evals = (
params.get("epochs_between_evals", 1) if self._is_training else 1)
return self.get_dataset(
batch_size=batch_size, epochs_between_evals=epochs_between_evals)
return input_fn
......@@ -405,15 +416,16 @@ class BaseDataConstructor(threading.Thread):
(self._train_pos_count,) = self._train_pos_users.shape
self._elements_in_epoch = (1 + num_train_negatives) * self._train_pos_count
self.train_batches_per_epoch = self._count_batches(
self._elements_in_epoch, train_batch_size, batches_per_train_step)
self.train_batches_per_epoch = self._count_batches(self._elements_in_epoch,
train_batch_size,
batches_per_train_step)
# Evaluation
if eval_batch_size % (1 + rconst.NUM_EVAL_NEGATIVES):
raise ValueError("Eval batch size {} is not divisible by {}".format(
eval_batch_size, 1 + rconst.NUM_EVAL_NEGATIVES))
self._eval_users_per_batch = int(
eval_batch_size // (1 + rconst.NUM_EVAL_NEGATIVES))
self._eval_users_per_batch = int(eval_batch_size //
(1 + rconst.NUM_EVAL_NEGATIVES))
self._eval_elements_in_epoch = num_users * (1 + rconst.NUM_EVAL_NEGATIVES)
self.eval_batches_per_epoch = self._count_batches(
self._eval_elements_in_epoch, eval_batch_size, batches_per_eval_step)
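As a concrete check of the divisibility requirement above, assuming rconst.NUM_EVAL_NEGATIVES = 999 (consistent with the EVAL_BATCH_SIZE = 4000 test constant further down):

eval_batch_size = 4000                          # as in the tests below
users_per_batch = eval_batch_size // (1 + 999)  # = 4 users per eval batch
# Each eval user contributes 1 positive and 999 negatives, so a 4000-element
# batch holds exactly 4 users and the modulus check passes.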
......@@ -450,12 +462,16 @@ class BaseDataConstructor(threading.Thread):
multiplier = ("(x{} devices)".format(self._batches_per_train_step)
if self._batches_per_train_step > 1 else "")
summary = SUMMARY_TEMPLATE.format(
spacer=" ", num_users=self._num_users, num_items=self._num_items,
spacer=" ",
num_users=self._num_users,
num_items=self._num_items,
train_pos_ct=self._train_pos_count,
train_batch_size=self.train_batch_size,
train_batch_ct=self.train_batches_per_epoch,
eval_pos_ct=self._num_users, eval_batch_size=self.eval_batch_size,
eval_batch_ct=self.eval_batches_per_epoch, multiplier=multiplier)
eval_pos_ct=self._num_users,
eval_batch_size=self.eval_batch_size,
eval_batch_ct=self.eval_batches_per_epoch,
multiplier=multiplier)
return super(BaseDataConstructor, self).__str__() + "\n" + summary
@staticmethod
......@@ -514,8 +530,9 @@ class BaseDataConstructor(threading.Thread):
i: The index of the batch. This is used when stream_files=True to assign
data to file shards.
"""
batch_indices = self._current_epoch_order[i * self.train_batch_size:
(i + 1) * self.train_batch_size]
batch_indices = self._current_epoch_order[i *
self.train_batch_size:(i + 1) *
self.train_batch_size]
(mask_start_index,) = batch_indices.shape
batch_ind_mod = np.mod(batch_indices, self._train_pos_count)
......@@ -578,8 +595,9 @@ class BaseDataConstructor(threading.Thread):
map_args = list(range(self.train_batches_per_epoch))
self._current_epoch_order = next(self._shuffle_iterator)
get_pool = (popen_helper.get_fauxpool if self.deterministic else
popen_helper.get_threadpool)
get_pool = (
popen_helper.get_fauxpool
if self.deterministic else popen_helper.get_threadpool)
with get_pool(6) as pool:
pool.map(self._get_training_batch, map_args)
self._train_dataset.end_construction()
......@@ -602,8 +620,8 @@ class BaseDataConstructor(threading.Thread):
users: An array of users in a batch. (should be identical along axis 1)
positive_items: An array (batch_size x 1) of positive item indices.
negative_items: An array of negative item indices.
users_per_batch: How many users should be in the batch. This is passed
as an argument so that ncf_test.py can use this method.
users_per_batch: How many users should be in the batch. This is passed as
an argument so that ncf_test.py can use this method.
Returns:
User, item, and duplicate_mask arrays.
......@@ -635,11 +653,14 @@ class BaseDataConstructor(threading.Thread):
"""
low_index = i * self._eval_users_per_batch
high_index = (i + 1) * self._eval_users_per_batch
users = np.repeat(self._eval_pos_users[low_index:high_index, np.newaxis],
1 + rconst.NUM_EVAL_NEGATIVES, axis=1)
users = np.repeat(
self._eval_pos_users[low_index:high_index, np.newaxis],
1 + rconst.NUM_EVAL_NEGATIVES,
axis=1)
positive_items = self._eval_pos_items[low_index:high_index, np.newaxis]
negative_items = (self.lookup_negative_items(negative_users=users[:, :-1])
.reshape(-1, rconst.NUM_EVAL_NEGATIVES))
negative_items = (
self.lookup_negative_items(negative_users=users[:, :-1]).reshape(
-1, rconst.NUM_EVAL_NEGATIVES))
users, items, duplicate_mask = self._assemble_eval_batch(
users, positive_items, negative_items, self._eval_users_per_batch)
......@@ -664,8 +685,9 @@ class BaseDataConstructor(threading.Thread):
self._eval_dataset.start_construction()
map_args = [i for i in range(self.eval_batches_per_epoch)]
get_pool = (popen_helper.get_fauxpool if self.deterministic else
popen_helper.get_threadpool)
get_pool = (
popen_helper.get_fauxpool
if self.deterministic else popen_helper.get_threadpool)
with get_pool(6) as pool:
pool.map(self._get_eval_batch, map_args)
self._eval_dataset.end_construction()
......@@ -677,12 +699,12 @@ class BaseDataConstructor(threading.Thread):
# It isn't feasible to provide a foolproof check, so this is designed to
# catch most failures rather than provide an exhaustive guard.
if self._fatal_exception is not None:
raise ValueError("Fatal exception in the data production loop: {}"
.format(self._fatal_exception))
raise ValueError("Fatal exception in the data production loop: {}".format(
self._fatal_exception))
return (
self._train_dataset.make_input_fn(self.train_batch_size) if is_training
else self._eval_dataset.make_input_fn(self.eval_batch_size))
return (self._train_dataset.make_input_fn(self.train_batch_size)
if is_training else self._eval_dataset.make_input_fn(
self.eval_batch_size))
def increment_request_epoch(self):
self._train_dataset.increment_request_epoch()
......@@ -714,8 +736,9 @@ class DummyConstructor(threading.Thread):
# Estimator passes batch_size during training and eval_batch_size during
# eval.
batch_size = (params["batch_size"] if is_training else
params.get("eval_batch_size") or params["batch_size"])
batch_size = (
params["batch_size"] if is_training else
params.get("eval_batch_size") or params["batch_size"])
num_users = params["num_users"]
num_items = params["num_items"]
......@@ -795,6 +818,7 @@ class MaterializedDataConstructor(BaseDataConstructor):
a pre-compute which is quadratic in problem size will still fit in memory. A
more scalable lookup method is in the works.
"""
def __init__(self, *args, **kwargs):
super(MaterializedDataConstructor, self).__init__(*args, **kwargs)
self._negative_table = None
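A toy sketch of the materialized lookup this class describes (illustrative sizes only; the real table is built from self._train_pos_users and self._train_pos_items):

import numpy as np

num_users, num_items = 3, 5
positives = {0: [1, 3], 1: [0], 2: [2, 4]}  # each user's training positives

# Quadratic (num_users x num_items) table, padded with the dtype max so an
# out-of-bounds embedding lookup fails loudly instead of hitting item zero.
table = np.full((num_users, num_items), np.iinfo(np.uint16).max,
                dtype=np.uint16)
neg_counts = np.zeros(num_users, dtype=np.int64)
for u in range(num_users):
  negs = np.setdiff1d(np.arange(num_items), positives[u])
  neg_counts[u] = negs.shape[0]
  table[u, :negs.shape[0]] = negs

# Sampling a negative for a user is then a single row lookup.
sample = table[0, np.random.randint(neg_counts[0])]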
......@@ -807,8 +831,8 @@ class MaterializedDataConstructor(BaseDataConstructor):
self._train_pos_users[:-1])[:, 0] + 1
(upper_bound,) = self._train_pos_users.shape
index_bounds = [0] + inner_bounds.tolist() + [upper_bound]
self._negative_table = np.zeros(shape=(self._num_users, self._num_items),
dtype=rconst.ITEM_DTYPE)
self._negative_table = np.zeros(
shape=(self._num_users, self._num_items), dtype=rconst.ITEM_DTYPE)
# Set the table to the max value to make sure the embedding lookup will fail
# if we go out of bounds, rather than just overloading item zero.
......@@ -825,7 +849,7 @@ class MaterializedDataConstructor(BaseDataConstructor):
# call does not parallelize well. Multiprocessing incurs too much
# serialization overhead to be worthwhile.
for i in range(self._num_users):
positives = self._train_pos_items[index_bounds[i]:index_bounds[i+1]]
positives = self._train_pos_items[index_bounds[i]:index_bounds[i + 1]]
negatives = np.delete(full_set, positives)
self._per_user_neg_count[i] = self._num_items - positives.shape[0]
self._negative_table[i, :self._per_user_neg_count[i]] = negatives
......@@ -848,6 +872,7 @@ class BisectionDataConstructor(BaseDataConstructor):
it, at which point the item id for the ith negative is a simple algebraic
expression.
"""
def __init__(self, *args, **kwargs):
super(BisectionDataConstructor, self).__init__(*args, **kwargs)
self.index_bounds = None
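A pure-Python sketch of the bisection idea from the docstring (toy data; the real class vectorizes the search and special-cases negatives that lie above the last positive):

import bisect

sorted_pos = [2, 5, 6]  # one user's sorted positive items among ids 0..9
# total_neg[j] = number of negative ids below sorted_pos[j].
total_neg = [p - j for j, p in enumerate(sorted_pos)]  # [2, 4, 4]

def kth_negative(k):
  j = bisect.bisect_right(total_neg, k) - 1  # last positive with <= k negatives before it
  if j < 0:
    return k  # the k-th negative precedes every positive
  return sorted_pos[j] + 1 + (k - total_neg[j])

assert [kth_negative(k) for k in range(7)] == [0, 1, 3, 4, 7, 8, 9]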
......@@ -855,7 +880,7 @@ class BisectionDataConstructor(BaseDataConstructor):
self._total_negatives = None
def _index_segment(self, user):
lower, upper = self.index_bounds[user:user+2]
lower, upper = self.index_bounds[user:user + 2]
items = self._sorted_train_pos_items[lower:upper]
negatives_since_last_positive = np.concatenate(
......@@ -877,11 +902,11 @@ class BisectionDataConstructor(BaseDataConstructor):
self._sorted_train_pos_items = self._train_pos_items.copy()
for i in range(self._num_users):
lower, upper = self.index_bounds[i:i+2]
lower, upper = self.index_bounds[i:i + 2]
self._sorted_train_pos_items[lower:upper].sort()
self._total_negatives = np.concatenate([
self._index_segment(i) for i in range(self._num_users)])
self._total_negatives = np.concatenate(
[self._index_segment(i) for i in range(self._num_users)])
logging.info("Negative total vector built. Time: {:.1f} seconds".format(
timeit.default_timer() - start_time))
......@@ -912,8 +937,7 @@ class BisectionDataConstructor(BaseDataConstructor):
use_shortcut = neg_item_choice >= self._total_negatives[right_index]
output[use_shortcut] = (
self._sorted_train_pos_items[right_index] + 1 +
(neg_item_choice - self._total_negatives[right_index])
)[use_shortcut]
(neg_item_choice - self._total_negatives[right_index]))[use_shortcut]
if np.all(use_shortcut):
# The bisection code is ill-posed when there are no elements.
......@@ -943,8 +967,7 @@ class BisectionDataConstructor(BaseDataConstructor):
output[not_use_shortcut] = (
self._sorted_train_pos_items[right_index] -
(self._total_negatives[right_index] - neg_item_choice)
)
(self._total_negatives[right_index] - neg_item_choice))
assert np.all(output >= 0)
......
......@@ -25,6 +25,7 @@ import time
import timeit
# pylint: disable=wrong-import-order
from absl import logging
import numpy as np
import pandas as pd
......@@ -37,10 +38,9 @@ from official.recommendation import constants as rconst
from official.recommendation import data_pipeline
from official.recommendation import movielens
_EXPECTED_CACHE_KEYS = (
rconst.TRAIN_USER_KEY, rconst.TRAIN_ITEM_KEY, rconst.EVAL_USER_KEY,
rconst.EVAL_ITEM_KEY, rconst.USER_MAP, rconst.ITEM_MAP)
_EXPECTED_CACHE_KEYS = (rconst.TRAIN_USER_KEY, rconst.TRAIN_ITEM_KEY,
rconst.EVAL_USER_KEY, rconst.EVAL_ITEM_KEY,
rconst.USER_MAP, rconst.ITEM_MAP)
def read_dataframe(
......@@ -178,17 +178,20 @@ def _filter_index_sort(raw_rating_path: Text,
eval_df, train_df = grouped.tail(1), grouped.apply(lambda x: x.iloc[:-1])
data = {
rconst.TRAIN_USER_KEY: train_df[movielens.USER_COLUMN]
.values.astype(rconst.USER_DTYPE),
rconst.TRAIN_ITEM_KEY: train_df[movielens.ITEM_COLUMN]
.values.astype(rconst.ITEM_DTYPE),
rconst.EVAL_USER_KEY: eval_df[movielens.USER_COLUMN]
.values.astype(rconst.USER_DTYPE),
rconst.EVAL_ITEM_KEY: eval_df[movielens.ITEM_COLUMN]
.values.astype(rconst.ITEM_DTYPE),
rconst.USER_MAP: user_map,
rconst.ITEM_MAP: item_map,
"create_time": time.time(),
rconst.TRAIN_USER_KEY:
train_df[movielens.USER_COLUMN].values.astype(rconst.USER_DTYPE),
rconst.TRAIN_ITEM_KEY:
train_df[movielens.ITEM_COLUMN].values.astype(rconst.ITEM_DTYPE),
rconst.EVAL_USER_KEY:
eval_df[movielens.USER_COLUMN].values.astype(rconst.USER_DTYPE),
rconst.EVAL_ITEM_KEY:
eval_df[movielens.ITEM_COLUMN].values.astype(rconst.ITEM_DTYPE),
rconst.USER_MAP:
user_map,
rconst.ITEM_MAP:
item_map,
"create_time":
time.time(),
}
logging.info("Writing raw data cache.")
......@@ -217,8 +220,8 @@ def instantiate_pipeline(dataset,
for the input pipeline.
deterministic: Tell the data constructor to produce data deterministically.
epoch_dir: Directory in which to store the training epochs.
generate_data_offline: Boolean, whether the data pipeline is run offline
or during training.
generate_data_offline: Boolean, whether the data pipeline is run offline or
during training.
"""
logging.info("Beginning data preprocessing.")
......@@ -258,8 +261,8 @@ def instantiate_pipeline(dataset,
create_data_offline=generate_data_offline)
run_time = timeit.default_timer() - st
logging.info("Data preprocessing complete. Time: {:.1f} sec."
.format(run_time))
logging.info(
"Data preprocessing complete. Time: {:.1f} sec.".format(run_time))
print(producer)
return num_users, num_items, producer
......@@ -23,6 +23,7 @@ import hashlib
import os
import mock
import numpy as np
import scipy.stats
import tensorflow as tf
......@@ -32,7 +33,6 @@ from official.recommendation import data_preprocessing
from official.recommendation import movielens
from official.recommendation import popen_helper
DATASET = "ml-test"
NUM_USERS = 1000
NUM_ITEMS = 2000
......@@ -41,7 +41,6 @@ BATCH_SIZE = 2048
EVAL_BATCH_SIZE = 4000
NUM_NEG = 4
END_TO_END_TRAIN_MD5 = "b218738e915e825d03939c5e305a2698"
END_TO_END_EVAL_MD5 = "d753d0f3186831466d6e218163a9501e"
FRESH_RANDOMNESS_MD5 = "63d0dff73c0e5f1048fbdc8c65021e22"
......@@ -136,8 +135,11 @@ class BaseTest(tf.test.TestCase):
def _test_end_to_end(self, constructor_type):
params = self.make_params(train_epochs=1)
_, _, producer = data_preprocessing.instantiate_pipeline(
dataset=DATASET, data_dir=self.temp_data_dir, params=params,
constructor_type=constructor_type, deterministic=True)
dataset=DATASET,
data_dir=self.temp_data_dir,
params=params,
constructor_type=constructor_type,
deterministic=True)
producer.start()
producer.join()
......@@ -258,8 +260,11 @@ class BaseTest(tf.test.TestCase):
train_epochs = 5
params = self.make_params(train_epochs=train_epochs)
_, _, producer = data_preprocessing.instantiate_pipeline(
dataset=DATASET, data_dir=self.temp_data_dir, params=params,
constructor_type=constructor_type, deterministic=True)
dataset=DATASET,
data_dir=self.temp_data_dir,
params=params,
constructor_type=constructor_type,
deterministic=True)
producer.start()
......@@ -298,8 +303,8 @@ class BaseTest(tf.test.TestCase):
self.assertRegexpMatches(md5.hexdigest(), FRESH_RANDOMNESS_MD5)
# The positive examples should appear exactly once each epoch
self.assertAllEqual(list(positive_counts.values()),
[train_epochs for _ in positive_counts])
self.assertAllEqual(
list(positive_counts.values()), [train_epochs for _ in positive_counts])
# The threshold for the negatives is heuristic; in general repeats are
# expected but should not appear too frequently.
......@@ -317,8 +322,8 @@ class BaseTest(tf.test.TestCase):
# The frequency of occurrence of a given negative pair should follow an
# approximately binomial distribution in the limit that the cardinality of
# the negative pair set >> number of samples per epoch.
approx_pdf = scipy.stats.binom.pmf(k=np.arange(train_epochs+1),
n=train_epochs, p=e_sample)
approx_pdf = scipy.stats.binom.pmf(
k=np.arange(train_epochs + 1), n=train_epochs, p=e_sample)
# Tally the actual observed counts.
count_distribution = [0 for _ in range(train_epochs + 1)]
......
......@@ -27,6 +27,7 @@ import tempfile
import zipfile
# pylint: disable=g-bad-import-order
# Import libraries
import numpy as np
import pandas as pd
import six
......
......@@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common functionalities used by both Keras and Estimator implementations.
"""
"""Common functionalities used by both Keras and Estimator implementations."""
from __future__ import absolute_import
from __future__ import division
......@@ -23,6 +22,7 @@ import json
import os
# pylint: disable=g-bad-import-order
import numpy as np
from absl import flags
from absl import logging
......@@ -56,7 +56,9 @@ def get_inputs(params):
num_eval_steps = rconst.SYNTHETIC_BATCHES_PER_EPOCH
else:
num_users, num_items, producer = data_preprocessing.instantiate_pipeline(
dataset=FLAGS.dataset, data_dir=FLAGS.data_dir, params=params,
dataset=FLAGS.dataset,
data_dir=FLAGS.data_dir,
params=params,
constructor_type=FLAGS.constructor_type,
deterministic=FLAGS.seed is not None)
num_train_steps = producer.train_batches_per_epoch
......@@ -108,16 +110,17 @@ def get_v1_distribution_strategy(params):
"""Returns the distribution strategy to use."""
if params["use_tpu"]:
# Some of the networking libraries are quite chatty.
for name in ["googleapiclient.discovery", "googleapiclient.discovery_cache",
"oauth2client.transport"]:
for name in [
"googleapiclient.discovery", "googleapiclient.discovery_cache",
"oauth2client.transport"
]:
logging.getLogger(name).setLevel(logging.ERROR)
tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
tpu=params["tpu"],
zone=params["tpu_zone"],
project=params["tpu_gcp_project"],
coordinator_name="coordinator"
)
coordinator_name="coordinator")
logging.info("Issuing reset command to TPU to ensure a clean state.")
tf.Session.reset(tpu_cluster_resolver.get_master())
......@@ -126,10 +129,12 @@ def get_v1_distribution_strategy(params):
# by reading the `TF_CONFIG` environment variable, and the coordinator
# is used by StreamingFilesDataset.
tf_config_env = {
"session_master": tpu_cluster_resolver.get_master(),
"eval_session_master": tpu_cluster_resolver.get_master(),
"coordinator": tpu_cluster_resolver.cluster_spec()
.as_dict()["coordinator"]
"session_master":
tpu_cluster_resolver.get_master(),
"eval_session_master":
tpu_cluster_resolver.get_master(),
"coordinator":
tpu_cluster_resolver.cluster_spec().as_dict()["coordinator"]
}
os.environ["TF_CONFIG"] = json.dumps(tf_config_env)
......@@ -146,10 +151,16 @@ def get_v1_distribution_strategy(params):
def define_ncf_flags():
"""Add flags for running ncf_main."""
# Add common flags
flags_core.define_base(model_dir=True, clean=True, train_epochs=True,
epochs_between_evals=True, export_dir=False,
run_eagerly=True, stop_threshold=True, num_gpu=True,
distribution_strategy=True)
flags_core.define_base(
model_dir=True,
clean=True,
train_epochs=True,
epochs_between_evals=True,
export_dir=False,
run_eagerly=True,
stop_threshold=True,
num_gpu=True,
distribution_strategy=True)
flags_core.define_performance(
synthetic_data=True,
dtype=True,
......@@ -171,69 +182,82 @@ def define_ncf_flags():
dataset=movielens.ML_1M,
train_epochs=2,
batch_size=99000,
tpu=None
)
tpu=None)
# Add ncf-specific flags
flags.DEFINE_boolean(
name="download_if_missing", default=True, help=flags_core.help_wrap(
name="download_if_missing",
default=True,
help=flags_core.help_wrap(
"Download data to data_dir if it is not already present."))
flags.DEFINE_integer(
name="eval_batch_size", default=None, help=flags_core.help_wrap(
name="eval_batch_size",
default=None,
help=flags_core.help_wrap(
"The batch size used for evaluation. This should generally be larger"
"than the training batch size as the lack of back propagation during"
"evaluation can allow for larger batch sizes to fit in memory. If not"
"specified, the training batch size (--batch_size) will be used."))
flags.DEFINE_integer(
name="num_factors", default=8,
name="num_factors",
default=8,
help=flags_core.help_wrap("The Embedding size of MF model."))
# Set the default as a list of strings to be consistent with input arguments
flags.DEFINE_list(
name="layers", default=["64", "32", "16", "8"],
name="layers",
default=["64", "32", "16", "8"],
help=flags_core.help_wrap(
"The sizes of hidden layers for MLP. Example "
"to specify different sizes of MLP layers: --layers=32,16,8,4"))
flags.DEFINE_float(
name="mf_regularization", default=0.,
name="mf_regularization",
default=0.,
help=flags_core.help_wrap(
"The regularization factor for MF embeddings. The factor is used by "
"regularizer which allows to apply penalties on layer parameters or "
"layer activity during optimization."))
flags.DEFINE_list(
name="mlp_regularization", default=["0.", "0.", "0.", "0."],
name="mlp_regularization",
default=["0.", "0.", "0.", "0."],
help=flags_core.help_wrap(
"The regularization factor for each MLP layer. See mf_regularization "
"help for more info about regularization factor."))
flags.DEFINE_integer(
name="num_neg", default=4,
name="num_neg",
default=4,
help=flags_core.help_wrap(
"The Number of negative instances to pair with a positive instance."))
flags.DEFINE_float(
name="learning_rate", default=0.001,
name="learning_rate",
default=0.001,
help=flags_core.help_wrap("The learning rate."))
flags.DEFINE_float(
name="beta1", default=0.9,
name="beta1",
default=0.9,
help=flags_core.help_wrap("beta1 hyperparameter for the Adam optimizer."))
flags.DEFINE_float(
name="beta2", default=0.999,
name="beta2",
default=0.999,
help=flags_core.help_wrap("beta2 hyperparameter for the Adam optimizer."))
flags.DEFINE_float(
name="epsilon", default=1e-8,
name="epsilon",
default=1e-8,
help=flags_core.help_wrap("epsilon hyperparameter for the Adam "
"optimizer."))
flags.DEFINE_float(
name="hr_threshold", default=1.0,
name="hr_threshold",
default=1.0,
help=flags_core.help_wrap(
"If passed, training will stop when the evaluation metric HR is "
"greater than or equal to hr_threshold. For dataset ml-1m, the "
......@@ -242,8 +266,10 @@ def define_ncf_flags():
"achieved by MLPerf implementation."))
flags.DEFINE_enum(
name="constructor_type", default="bisection",
enum_values=["bisection", "materialized"], case_sensitive=False,
name="constructor_type",
default="bisection",
enum_values=["bisection", "materialized"],
case_sensitive=False,
help=flags_core.help_wrap(
"Strategy to use for generating false negatives. materialized has a"
"precompute that scales badly, but a faster per-epoch construction"
......@@ -265,7 +291,8 @@ def define_ncf_flags():
help=flags_core.help_wrap("Path to input meta data file."))
flags.DEFINE_bool(
name="ml_perf", default=False,
name="ml_perf",
default=False,
help=flags_core.help_wrap(
"If set, changes the behavior of the model slightly to match the "
"MLPerf reference implementations here: \n"
......@@ -280,23 +307,26 @@ def define_ncf_flags():
"not stable."))
flags.DEFINE_bool(
name="output_ml_perf_compliance_logging", default=False,
name="output_ml_perf_compliance_logging",
default=False,
help=flags_core.help_wrap(
"If set, output the MLPerf compliance logging. This is only useful "
"if one is running the model for MLPerf. See "
"https://github.com/mlperf/policies/blob/master/training_rules.adoc"
"#submission-compliance-logs for details. This uses sudo and so may "
"ask for your password, as root access is needed to clear the system "
"caches, which is required for MLPerf compliance."
)
)
"caches, which is required for MLPerf compliance."))
flags.DEFINE_integer(
name="seed", default=None, help=flags_core.help_wrap(
name="seed",
default=None,
help=flags_core.help_wrap(
"This value will be used to seed both NumPy and TensorFlow."))
@flags.validator("eval_batch_size", "eval_batch_size must be at least {}"
.format(rconst.NUM_EVAL_NEGATIVES + 1))
@flags.validator(
"eval_batch_size",
"eval_batch_size must be at least {}".format(rconst.NUM_EVAL_NEGATIVES +
1))
def eval_size_check(eval_batch_size):
return (eval_batch_size is None or
int(eval_batch_size) > rconst.NUM_EVAL_NEGATIVES)
......
......@@ -21,6 +21,7 @@ from __future__ import print_function
import functools
# pylint: disable=g-bad-import-order
import tensorflow.compat.v2 as tf
# pylint: enable=g-bad-import-order
......@@ -130,8 +131,8 @@ def create_ncf_input_data(params,
from tf record files. Must be specified when params["train_input_dataset"]
is specified.
strategy: Distribution strategy used for distributed training. If specified,
used to assert that evaluation batch size is correctly a multiple of
total number of devices used.
used to assert that evaluation batch size is correctly a multiple of total
number of devices used.
Returns:
(training dataset, evaluation dataset, train steps per epoch,
......
......@@ -26,6 +26,7 @@ import json
import os
# pylint: disable=g-bad-import-order
from absl import app
from absl import flags
from absl import logging
......@@ -42,7 +43,6 @@ from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
from official.utils.misc import model_helpers
FLAGS = flags.FLAGS
......@@ -50,9 +50,7 @@ def metric_fn(logits, dup_mask, match_mlperf):
dup_mask = tf.cast(dup_mask, tf.float32)
logits = tf.slice(logits, [0, 1], [-1, -1])
in_top_k, _, metric_weights, _ = neumf_model.compute_top_k_and_ndcg(
logits,
dup_mask,
match_mlperf)
logits, dup_mask, match_mlperf)
metric_weights = tf.cast(metric_weights, tf.float32)
return in_top_k, metric_weights
......@@ -152,9 +150,10 @@ class CustomEarlyStopping(tf.keras.callbacks.Callback):
logs = logs or {}
monitor_value = logs.get(self.monitor)
if monitor_value is None:
logging.warning("Early stopping conditioned on metric `%s` "
"which is not available. Available metrics are: %s",
self.monitor, ",".join(list(logs.keys())))
logging.warning(
"Early stopping conditioned on metric `%s` "
"which is not available. Available metrics are: %s", self.monitor,
",".join(list(logs.keys())))
return monitor_value
......@@ -181,12 +180,9 @@ def _get_keras_model(params):
logits = base_model.output
zeros = tf.keras.layers.Lambda(
lambda x: x * 0)(logits)
zeros = tf.keras.layers.Lambda(lambda x: x * 0)(logits)
softmax_logits = tf.keras.layers.concatenate(
[zeros, logits],
axis=-1)
softmax_logits = tf.keras.layers.concatenate([zeros, logits], axis=-1)
# Custom training loop calculates loss and metric as a part of
# training/evaluation step function.
......@@ -204,7 +200,8 @@ def _get_keras_model(params):
movielens.ITEM_COLUMN: item_input,
rconst.VALID_POINT_MASK: valid_pt_mask_input,
rconst.DUPLICATE_MASK: dup_mask_input,
rconst.TRAIN_LABEL_KEY: label_input},
rconst.TRAIN_LABEL_KEY: label_input
},
outputs=softmax_logits)
keras_model.summary()
......@@ -412,8 +409,7 @@ def run_ncf_custom_training(params,
optimizer.apply_gradients(grads)
return loss
per_replica_losses = strategy.run(
step_fn, args=(next(train_iterator),))
per_replica_losses = strategy.run(step_fn, args=(next(train_iterator),))
mean_loss = strategy.reduce(
tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
return mean_loss
......@@ -432,8 +428,7 @@ def run_ncf_custom_training(params,
return hr_sum, hr_count
per_replica_hr_sum, per_replica_hr_count = (
strategy.run(
step_fn, args=(next(eval_iterator),)))
strategy.run(step_fn, args=(next(eval_iterator),)))
hr_sum = strategy.reduce(
tf.distribute.ReduceOp.SUM, per_replica_hr_sum, axis=None)
hr_count = strategy.reduce(
......@@ -482,8 +477,8 @@ def run_ncf_custom_training(params,
# Write train loss once in every 1000 steps.
if train_summary_writer and step % 1000 == 0:
with train_summary_writer.as_default():
tf.summary.scalar("training_loss", train_loss/(step + 1),
step=current_step)
tf.summary.scalar(
"training_loss", train_loss / (step + 1), step=current_step)
for c in callbacks:
c.on_batch_end(current_step)
......@@ -552,7 +547,7 @@ def build_stats(loss, eval_result, time_callback):
if len(timestamp_log) > 1:
stats["avg_exp_per_second"] = (
time_callback.batch_size * time_callback.log_steps *
(len(time_callback.timestamp_log)-1) /
(len(time_callback.timestamp_log) - 1) /
(timestamp_log[-1].timestamp - timestamp_log[0].timestamp))
return stats
......
......@@ -48,64 +48,68 @@ class NcfTest(tf.test.TestCase):
_BASE_END_TO_END_FLAGS = ['-batch_size', '1044', '-train_epochs', '1']
@unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
@unittest.mock.patch.object(rconst, 'SYNTHETIC_BATCHES_PER_EPOCH', 100)
def test_end_to_end_keras_no_dist_strat(self):
integration.run_synthetic(
ncf_keras_main.main, tmp_root=self.get_temp_dir(),
ncf_keras_main.main,
tmp_root=self.get_temp_dir(),
extra_flags=self._BASE_END_TO_END_FLAGS +
['-distribution_strategy', 'off'])
@unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
@unittest.mock.patch.object(rconst, 'SYNTHETIC_BATCHES_PER_EPOCH', 100)
def test_end_to_end_keras_dist_strat(self):
integration.run_synthetic(
ncf_keras_main.main, tmp_root=self.get_temp_dir(),
ncf_keras_main.main,
tmp_root=self.get_temp_dir(),
extra_flags=self._BASE_END_TO_END_FLAGS + ['-num_gpus', '0'])
@unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
@unittest.mock.patch.object(rconst, 'SYNTHETIC_BATCHES_PER_EPOCH', 100)
def test_end_to_end_keras_dist_strat_ctl(self):
flags = (self._BASE_END_TO_END_FLAGS +
['-num_gpus', '0'] +
['-keras_use_ctl', 'True'])
flags = (
self._BASE_END_TO_END_FLAGS + ['-num_gpus', '0'] +
['-keras_use_ctl', 'True'])
integration.run_synthetic(
ncf_keras_main.main, tmp_root=self.get_temp_dir(),
extra_flags=flags)
ncf_keras_main.main, tmp_root=self.get_temp_dir(), extra_flags=flags)
@unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
@unittest.mock.patch.object(rconst, 'SYNTHETIC_BATCHES_PER_EPOCH', 100)
def test_end_to_end_keras_1_gpu_dist_strat_fp16(self):
if context.num_gpus() < 1:
self.skipTest(
"{} GPUs are not available for this test. {} GPUs are available".
format(1, context.num_gpus()))
'{} GPUs are not available for this test. {} GPUs are available'
.format(1, context.num_gpus()))
integration.run_synthetic(
ncf_keras_main.main, tmp_root=self.get_temp_dir(),
extra_flags=self._BASE_END_TO_END_FLAGS + ['-num_gpus', '1',
'--dtype', 'fp16'])
ncf_keras_main.main,
tmp_root=self.get_temp_dir(),
extra_flags=self._BASE_END_TO_END_FLAGS +
['-num_gpus', '1', '--dtype', 'fp16'])
@unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
@unittest.mock.patch.object(rconst, 'SYNTHETIC_BATCHES_PER_EPOCH', 100)
def test_end_to_end_keras_1_gpu_dist_strat_ctl_fp16(self):
if context.num_gpus() < 1:
self.skipTest(
'{} GPUs are not available for this test. {} GPUs are available'.
format(1, context.num_gpus()))
'{} GPUs are not available for this test. {} GPUs are available'
.format(1, context.num_gpus()))
integration.run_synthetic(
ncf_keras_main.main, tmp_root=self.get_temp_dir(),
extra_flags=self._BASE_END_TO_END_FLAGS + ['-num_gpus', '1',
'--dtype', 'fp16',
'--keras_use_ctl'])
ncf_keras_main.main,
tmp_root=self.get_temp_dir(),
extra_flags=self._BASE_END_TO_END_FLAGS +
['-num_gpus', '1', '--dtype', 'fp16', '--keras_use_ctl'])
@unittest.mock.patch.object(rconst, 'SYNTHETIC_BATCHES_PER_EPOCH', 100)
def test_end_to_end_keras_2_gpu_fp16(self):
if context.num_gpus() < 2:
self.skipTest(
"{} GPUs are not available for this test. {} GPUs are available".
format(2, context.num_gpus()))
'{} GPUs are not available for this test. {} GPUs are available'
.format(2, context.num_gpus()))
integration.run_synthetic(
ncf_keras_main.main, tmp_root=self.get_temp_dir(),
extra_flags=self._BASE_END_TO_END_FLAGS + ['-num_gpus', '2',
'--dtype', 'fp16'])
ncf_keras_main.main,
tmp_root=self.get_temp_dir(),
extra_flags=self._BASE_END_TO_END_FLAGS +
['-num_gpus', '2', '--dtype', 'fp16'])
if __name__ == "__main__":
if __name__ == '__main__':
tf.test.main()
......@@ -111,8 +111,7 @@ def neumf_model_fn(features, labels, mode, params):
loss = tf.compat.v1.losses.sparse_softmax_cross_entropy(
labels=labels,
logits=softmax_logits,
weights=tf.cast(valid_pt_mask, tf.float32)
)
weights=tf.cast(valid_pt_mask, tf.float32))
tf.identity(loss, name="cross_entropy")
......@@ -196,15 +195,19 @@ def construct_model(user_input: tf.Tensor, item_input: tf.Tensor,
# GMF part
mf_user_latent = tf.keras.layers.Lambda(
mf_slice_fn, name="embedding_user_mf")(embedding_user)
mf_slice_fn, name="embedding_user_mf")(
embedding_user)
mf_item_latent = tf.keras.layers.Lambda(
mf_slice_fn, name="embedding_item_mf")(embedding_item)
mf_slice_fn, name="embedding_item_mf")(
embedding_item)
# MLP part
mlp_user_latent = tf.keras.layers.Lambda(
mlp_slice_fn, name="embedding_user_mlp")(embedding_user)
mlp_slice_fn, name="embedding_user_mlp")(
embedding_user)
mlp_item_latent = tf.keras.layers.Lambda(
mlp_slice_fn, name="embedding_item_mlp")(embedding_item)
mlp_slice_fn, name="embedding_item_mlp")(
embedding_item)
# Element-wise multiply
mf_vector = tf.keras.layers.multiply([mf_user_latent, mf_item_latent])
......@@ -225,8 +228,11 @@ def construct_model(user_input: tf.Tensor, item_input: tf.Tensor,
# Final prediction layer
logits = tf.keras.layers.Dense(
1, activation=None, kernel_initializer="lecun_uniform",
name=movielens.RATING_COLUMN)(predict_vector)
1,
activation=None,
kernel_initializer="lecun_uniform",
name=movielens.RATING_COLUMN)(
predict_vector)
# Print model topology.
model = tf.keras.models.Model([user_input, item_input], logits)
......@@ -263,8 +269,7 @@ def _get_estimator_spec_with_metrics(logits: tf.Tensor,
return tf.estimator.EstimatorSpec(
mode=tf.estimator.ModeKeys.EVAL,
loss=cross_entropy,
eval_metric_ops=metric_fn(in_top_k, ndcg, metric_weights)
)
eval_metric_ops=metric_fn(in_top_k, ndcg, metric_weights))
def compute_eval_loss_and_metrics_helper(logits: tf.Tensor,
......@@ -335,9 +340,13 @@ def compute_eval_loss_and_metrics_helper(logits: tf.Tensor,
# Examples are provided by the eval Dataset in a structured format, so eval
# labels can be reconstructed on the fly.
eval_labels = tf.reshape(shape=(-1,), tensor=tf.one_hot(
tf.zeros(shape=(logits_by_user.shape[0],), dtype=tf.int32) +
rconst.NUM_EVAL_NEGATIVES, logits_by_user.shape[1], dtype=tf.int32))
eval_labels = tf.reshape(
shape=(-1,),
tensor=tf.one_hot(
tf.zeros(shape=(logits_by_user.shape[0],), dtype=tf.int32) +
rconst.NUM_EVAL_NEGATIVES,
logits_by_user.shape[1],
dtype=tf.int32))
eval_labels_float = tf.cast(eval_labels, tf.float32)
......@@ -346,13 +355,14 @@ def compute_eval_loss_and_metrics_helper(logits: tf.Tensor,
# weights for the negative examples we compute a loss which is consistent with
# the training data. (And provides apples-to-apples comparison)
negative_scale_factor = num_training_neg / rconst.NUM_EVAL_NEGATIVES
example_weights = (
(eval_labels_float + (1 - eval_labels_float) * negative_scale_factor) *
(1 + rconst.NUM_EVAL_NEGATIVES) / (1 + num_training_neg))
example_weights = ((eval_labels_float +
(1 - eval_labels_float) * negative_scale_factor) *
(1 + rconst.NUM_EVAL_NEGATIVES) / (1 + num_training_neg))
# Tile metric weights back to logit dimensions
expanded_metric_weights = tf.reshape(tf.tile(
metric_weights[:, tf.newaxis], (1, rconst.NUM_EVAL_NEGATIVES + 1)), (-1,))
expanded_metric_weights = tf.reshape(
tf.tile(metric_weights[:, tf.newaxis],
(1, rconst.NUM_EVAL_NEGATIVES + 1)), (-1,))
# ignore padded examples
example_weights *= tf.cast(expanded_metric_weights, tf.float32)
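A worked check of the weighting above, assuming num_training_neg = 4 and rconst.NUM_EVAL_NEGATIVES = 999 (consistent with NUM_NEG = 4 and EVAL_BATCH_SIZE = 4000 in the tests earlier):

# negative_scale_factor = 4 / 999
# positive weight = 1.0       * (1 + 999) / (1 + 4) = 200.0
# negative weight = (4 / 999) * (1 + 999) / (1 + 4) ~= 0.8008
# One user's row sums to 200 + 999 * 0.8008 ~= 1000, a mean weight of 1.0,
# and total positive:negative weight is 200:800 = 1:4, matching training.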
......@@ -362,12 +372,15 @@ def compute_eval_loss_and_metrics_helper(logits: tf.Tensor,
def metric_fn(top_k_tensor, ndcg_tensor, weight_tensor):
return {
rconst.HR_KEY: tf.compat.v1.metrics.mean(top_k_tensor,
weights=weight_tensor,
name=rconst.HR_METRIC_NAME),
rconst.NDCG_KEY: tf.compat.v1.metrics.mean(ndcg_tensor,
weights=weight_tensor,
name=rconst.NDCG_METRIC_NAME)
rconst.HR_KEY:
tf.compat.v1.metrics.mean(
top_k_tensor, weights=weight_tensor,
name=rconst.HR_METRIC_NAME),
rconst.NDCG_KEY:
tf.compat.v1.metrics.mean(
ndcg_tensor,
weights=weight_tensor,
name=rconst.NDCG_METRIC_NAME)
}
return cross_entropy, metric_fn, in_top_k, ndcg, metric_weights
......@@ -405,27 +418,26 @@ def compute_top_k_and_ndcg(logits: tf.Tensor,
# Determine the location of the first element in each row after the elements
# are sorted.
sort_indices = tf.argsort(
logits_by_user, axis=1, direction="DESCENDING")
sort_indices = tf.argsort(logits_by_user, axis=1, direction="DESCENDING")
# Use matrix multiplication to extract the position of the true item from the
# tensor of sorted indices. This approach is chosen because both GPUs and TPUs
# perform matrix multiplications very quickly. This is similar to np.argwhere.
# However, this is a special case because the target will only appear in
# sort_indices once.
one_hot_position = tf.cast(tf.equal(sort_indices, rconst.NUM_EVAL_NEGATIVES),
tf.int32)
one_hot_position = tf.cast(
tf.equal(sort_indices, rconst.NUM_EVAL_NEGATIVES), tf.int32)
sparse_positions = tf.multiply(
one_hot_position, tf.range(logits_by_user.shape[1])[tf.newaxis, :])
one_hot_position,
tf.range(logits_by_user.shape[1])[tf.newaxis, :])
position_vector = tf.reduce_sum(sparse_positions, axis=1)
in_top_k = tf.cast(tf.less(position_vector, rconst.TOP_K), tf.float32)
ndcg = tf.math.log(2.) / tf.math.log(
tf.cast(position_vector, tf.float32) + 2)
ndcg = tf.math.log(2.) / tf.math.log(tf.cast(position_vector, tf.float32) + 2)
ndcg *= in_top_k
# If a row is a padded row, all but the first element will be a duplicate.
metric_weights = tf.not_equal(tf.reduce_sum(duplicate_mask_by_user, axis=1),
rconst.NUM_EVAL_NEGATIVES)
metric_weights = tf.not_equal(
tf.reduce_sum(duplicate_mask_by_user, axis=1), rconst.NUM_EVAL_NEGATIVES)
return in_top_k, ndcg, metric_weights, logits_by_user
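A NumPy sketch of the position-extraction trick described in the comments above (toy values; NUM_EVAL_NEGATIVES is taken as 3 purely for readability):

import numpy as np

num_eval_negatives = 3
logits_by_user = np.array([[0.1, 0.9, 0.3, 0.7]])   # true item is column 3
sort_indices = np.argsort(-logits_by_user, axis=1)  # [[1, 3, 2, 0]]
one_hot_position = (sort_indices == num_eval_negatives).astype(np.int32)
position = (one_hot_position * np.arange(4)).sum(axis=1)  # [1]
ndcg = np.log(2.0) / np.log(position + 2.0)  # ~0.631 for a rank-1 hit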
......@@ -37,9 +37,7 @@ def permutation(args):
args: A size two tuple that will be unpacked into the size of the permutation
and the random seed. This form is used because starmap is not universally
available.
Returns:
A NumPy array containing a random permutation.
Returns: A NumPy array containing a random permutation.
"""
x, seed = args
......@@ -53,8 +51,11 @@ def permutation(args):
def very_slightly_biased_randint(max_val_vector):
sample_dtype = np.uint64
out_dtype = max_val_vector.dtype
samples = np.random.randint(low=0, high=np.iinfo(sample_dtype).max,
size=max_val_vector.shape, dtype=sample_dtype)
samples = np.random.randint(
low=0,
high=np.iinfo(sample_dtype).max,
size=max_val_vector.shape,
dtype=sample_dtype)
return np.mod(samples, max_val_vector.astype(sample_dtype)).astype(out_dtype)
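On the name: reducing 64-bit uniform samples mod max_val over-selects the residues below 2**64 % max_val by a relative excess of roughly max_val / 2**64, hence "very slightly biased". A small usage sketch:

max_vals = np.full(100000, 10, dtype=np.int32)
draws = very_slightly_biased_randint(max_vals)
assert draws.min() >= 0 and draws.max() < 10  # bias here is ~5e-19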
......@@ -88,5 +89,5 @@ def mask_duplicates(x, axis=1): # type: (np.ndarray, int) -> np.ndarray
# Duplicate values will have a difference of zero. By definition the first
# element is never a duplicate.
return np.where(diffs[np.arange(x.shape[0])[:, np.newaxis],
inv_x_sort_ind], 0, 1)
return np.where(diffs[np.arange(x.shape[0])[:, np.newaxis], inv_x_sort_ind],
0, 1)
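An illustrative call (inferring from the padded-row logic in compute_top_k_and_ndcg above that first occurrences map to 0 and repeats map to 1):

# x = np.array([[1, 7, 7, 3],
#               [4, 4, 4, 9]])
# mask_duplicates(x, axis=1)
# -> [[0, 0, 1, 0],
#     [0, 1, 1, 0]]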
......@@ -103,9 +103,9 @@ def minimize_using_explicit_allreduce(tape,
pre_allreduce_callbacks: A list of callback functions that takes gradients
and model variables pairs as input, manipulate them, and returns a new
gradients and model variables pairs. The callback functions will be
invoked in the list order and before gradients are allreduced.
With mixed precision training, the pre_allreduce_callbacks will be
applied on scaled_gradients. Default is no callbacks.
invoked in the list order and before gradients are allreduced. With
mixed precision training, the pre_allreduce_callbacks will be applied on
scaled_gradients. Default is no callbacks.
post_allreduce_callbacks: A list of callback functions that takes
gradients and model variables pairs as input, manipulate them, and
returns a new gradients and model variables pairs. The callback
......
......@@ -23,10 +23,18 @@ import tensorflow as tf
from official.utils.flags._conventions import help_wrap
def define_base(data_dir=True, model_dir=True, clean=False, train_epochs=False,
epochs_between_evals=False, stop_threshold=False,
batch_size=True, num_gpu=False, hooks=False, export_dir=False,
distribution_strategy=False, run_eagerly=False):
def define_base(data_dir=True,
model_dir=True,
clean=False,
train_epochs=False,
epochs_between_evals=False,
stop_threshold=False,
batch_size=True,
num_gpu=False,
hooks=False,
export_dir=False,
distribution_strategy=False,
run_eagerly=False):
"""Register base flags.
Args:
......@@ -35,8 +43,8 @@ def define_base(data_dir=True, model_dir=True, clean=False, train_epochs=False,
clean: Create a flag for removing the model_dir.
train_epochs: Create a flag to specify the number of training epochs.
epochs_between_evals: Create a flag to specify the frequency of testing.
stop_threshold: Create a flag to specify a threshold accuracy or other
eval metric which should trigger the end of training.
stop_threshold: Create a flag to specify a threshold accuracy or other eval
metric which should trigger the end of training.
batch_size: Create a flag to specify the batch size.
num_gpu: Create a flag to specify the number of GPUs used.
hooks: Create a flag to specify hooks for logging.
......@@ -44,6 +52,7 @@ def define_base(data_dir=True, model_dir=True, clean=False, train_epochs=False,
distribution_strategy: Create a flag to specify which Distribution Strategy
to use.
run_eagerly: Create a flag to specify to run eagerly op by op.
Returns:
A list of flags for core.py to marks as key flags.
"""
......@@ -51,38 +60,48 @@ def define_base(data_dir=True, model_dir=True, clean=False, train_epochs=False,
if data_dir:
flags.DEFINE_string(
name="data_dir", short_name="dd", default="/tmp",
name="data_dir",
short_name="dd",
default="/tmp",
help=help_wrap("The location of the input data."))
key_flags.append("data_dir")
if model_dir:
flags.DEFINE_string(
name="model_dir", short_name="md", default="/tmp",
name="model_dir",
short_name="md",
default="/tmp",
help=help_wrap("The location of the model checkpoint files."))
key_flags.append("model_dir")
if clean:
flags.DEFINE_boolean(
name="clean", default=False,
name="clean",
default=False,
help=help_wrap("If set, model_dir will be removed if it exists."))
key_flags.append("clean")
if train_epochs:
flags.DEFINE_integer(
name="train_epochs", short_name="te", default=1,
name="train_epochs",
short_name="te",
default=1,
help=help_wrap("The number of epochs used to train."))
key_flags.append("train_epochs")
if epochs_between_evals:
flags.DEFINE_integer(
name="epochs_between_evals", short_name="ebe", default=1,
name="epochs_between_evals",
short_name="ebe",
default=1,
help=help_wrap("The number of training epochs to run between "
"evaluations."))
key_flags.append("epochs_between_evals")
if stop_threshold:
flags.DEFINE_float(
name="stop_threshold", short_name="st",
name="stop_threshold",
short_name="st",
default=None,
help=help_wrap("If passed, training will stop at the earlier of "
"train_epochs and when the evaluation metric is "
......@@ -90,7 +109,9 @@ def define_base(data_dir=True, model_dir=True, clean=False, train_epochs=False,
if batch_size:
flags.DEFINE_integer(
name="batch_size", short_name="bs", default=32,
name="batch_size",
short_name="bs",
default=32,
help=help_wrap("Batch size for training and evaluation. When using "
"multiple gpus, this is the global batch size for "
"all devices. For example, if the batch size is 32 "
......@@ -100,49 +121,52 @@ def define_base(data_dir=True, model_dir=True, clean=False, train_epochs=False,
if num_gpu:
flags.DEFINE_integer(
name="num_gpus", short_name="ng",
name="num_gpus",
short_name="ng",
default=1,
help=help_wrap(
"How many GPUs to use at each worker with the "
"DistributionStrategies API. The default is 1."))
help=help_wrap("How many GPUs to use at each worker with the "
"DistributionStrategies API. The default is 1."))
if run_eagerly:
flags.DEFINE_boolean(
name="run_eagerly", default=False,
name="run_eagerly",
default=False,
help="Run the model op by op without building a model function.")
if hooks:
flags.DEFINE_list(
name="hooks", short_name="hk", default="LoggingTensorHook",
name="hooks",
short_name="hk",
default="LoggingTensorHook",
help=help_wrap(
u"A list of (case insensitive) strings to specify the names of "
u"training hooks. Example: `--hooks ProfilerHook,"
u"ExamplesPerSecondHook`\n See hooks_helper "
u"for details.")
)
u"for details."))
key_flags.append("hooks")
if export_dir:
flags.DEFINE_string(
name="export_dir", short_name="ed", default=None,
name="export_dir",
short_name="ed",
default=None,
help=help_wrap("If set, a SavedModel serialization of the model will "
"be exported to this directory at the end of training. "
"See the README for more details and relevant links.")
)
"See the README for more details and relevant links."))
key_flags.append("export_dir")
if distribution_strategy:
flags.DEFINE_string(
name="distribution_strategy", short_name="ds", default="mirrored",
name="distribution_strategy",
short_name="ds",
default="mirrored",
help=help_wrap("The Distribution Strategy to use for training. "
"Accepted values are 'off', 'one_device', "
"'mirrored', 'parameter_server', 'collective', "
"case insensitive. 'off' means not to use "
"Distribution Strategy; 'default' means to choose "
"from `MirroredStrategy` or `OneDeviceStrategy` "
"according to the number of GPUs.")
)
"according to the number of GPUs."))
return key_flags
......
......@@ -25,7 +25,8 @@ from official.utils.flags._conventions import help_wrap
def define_log_steps():
flags.DEFINE_integer(
name="log_steps", default=100,
name="log_steps",
default=100,
help="Frequency with which to log timing information with TimeHistory.")
return []
......@@ -45,13 +46,16 @@ def define_benchmark(benchmark_log_dir=True, bigquery_uploader=True):
key_flags = []
flags.DEFINE_enum(
name="benchmark_logger_type", default="BaseBenchmarkLogger",
name="benchmark_logger_type",
default="BaseBenchmarkLogger",
enum_values=["BaseBenchmarkLogger", "BenchmarkFileLogger"],
help=help_wrap("The type of benchmark logger to use. Defaults to using "
"BaseBenchmarkLogger which logs to STDOUT. Different "
"loggers will require other flags to be able to work."))
flags.DEFINE_string(
name="benchmark_test_id", short_name="bti", default=None,
name="benchmark_test_id",
short_name="bti",
default=None,
help=help_wrap("The unique test ID of the benchmark run. It could be the "
"combination of key parameters. It is hardware "
independent and could be used to compare the performance
......@@ -63,34 +67,43 @@ def define_benchmark(benchmark_log_dir=True, bigquery_uploader=True):
if benchmark_log_dir:
flags.DEFINE_string(
name="benchmark_log_dir", short_name="bld", default=None,
help=help_wrap("The location of the benchmark logging.")
)
name="benchmark_log_dir",
short_name="bld",
default=None,
help=help_wrap("The location of the benchmark logging."))
if bigquery_uploader:
flags.DEFINE_string(
name="gcp_project", short_name="gp", default=None,
name="gcp_project",
short_name="gp",
default=None,
help=help_wrap(
"The GCP project name where the benchmark will be uploaded."))
flags.DEFINE_string(
name="bigquery_data_set", short_name="bds", default="test_benchmark",
name="bigquery_data_set",
short_name="bds",
default="test_benchmark",
help=help_wrap(
"The Bigquery dataset name where the benchmark will be uploaded."))
flags.DEFINE_string(
name="bigquery_run_table", short_name="brt", default="benchmark_run",
name="bigquery_run_table",
short_name="brt",
default="benchmark_run",
help=help_wrap("The Bigquery table name where the benchmark run "
"information will be uploaded."))
flags.DEFINE_string(
name="bigquery_run_status_table", short_name="brst",
name="bigquery_run_status_table",
short_name="brst",
default="benchmark_run_status",
help=help_wrap("The Bigquery table name where the benchmark run "
"status information will be uploaded."))
flags.DEFINE_string(
name="bigquery_metric_table", short_name="bmt",
name="bigquery_metric_table",
short_name="bmt",
default="benchmark_metric",
help=help_wrap("The Bigquery table name where the benchmark metric "
"information will be uploaded."))
......@@ -98,7 +111,7 @@ def define_benchmark(benchmark_log_dir=True, bigquery_uploader=True):
@flags.multi_flags_validator(
["benchmark_logger_type", "benchmark_log_dir"],
message="--benchmark_logger_type=BenchmarkFileLogger will require "
"--benchmark_log_dir being set")
"--benchmark_log_dir being set")
def _check_benchmark_log_dir(flags_dict):
benchmark_logger_type = flags_dict["benchmark_logger_type"]
if benchmark_logger_type == "BenchmarkFileLogger":
......
......@@ -25,13 +25,12 @@ import functools
from absl import app as absl_app
from absl import flags
# This codifies help string conventions and makes it easy to update them if
# necessary. Currently the only major effect is that help bodies start on the
# line after flags are listed. All flag definitions should wrap the text bodies
# with help wrap when calling DEFINE_*.
_help_wrap = functools.partial(flags.text_wrap, length=80, indent="",
firstline_indent="\n")
_help_wrap = functools.partial(
flags.text_wrap, length=80, indent="", firstline_indent="\n")
# Pretty formatting causes issues when utf-8 is not installed on a system.
......@@ -46,6 +45,7 @@ def _stdout_utf8():
if _stdout_utf8():
help_wrap = _help_wrap
else:
def help_wrap(text, *args, **kwargs):
return _help_wrap(text, *args, **kwargs).replace(u"\ufeff", u"")
......
......@@ -26,11 +26,13 @@ from official.utils.flags._conventions import help_wrap
def require_cloud_storage(flag_names):
"""Register a validator to check directory flags.
Args:
flag_names: An iterable of strings containing the names of flags to be
checked.
"""
msg = "TPU requires GCS path for {}".format(", ".join(flag_names))
@flags.multi_flags_validator(["tpu"] + flag_names, message=msg)
def _path_check(flag_values): # pylint: disable=missing-docstring
if flag_values["tpu"] is None:
......@@ -47,8 +49,10 @@ def require_cloud_storage(flag_names):
def define_device(tpu=True):
"""Register device specific flags.
Args:
tpu: Create flags to specify TPU operation.
Returns:
A list of flags for core.py to mark as key flags.
"""
......@@ -57,7 +61,8 @@ def define_device(tpu=True):
if tpu:
flags.DEFINE_string(
name="tpu", default=None,
name="tpu",
default=None,
help=help_wrap(
"The Cloud TPU to use for training. This should be either the name "
"used when creating the Cloud TPU, or a "
......@@ -66,20 +71,24 @@ def define_device(tpu=True):
key_flags.append("tpu")
flags.DEFINE_string(
name="tpu_zone", default=None,
name="tpu_zone",
default=None,
help=help_wrap(
"[Optional] GCE zone where the Cloud TPU is located in. If not "
"specified, we will attempt to automatically detect the GCE "
"project from metadata."))
flags.DEFINE_string(
name="tpu_gcp_project", default=None,
name="tpu_gcp_project",
default=None,
help=help_wrap(
"[Optional] Project name for the Cloud TPU-enabled project. If not "
"specified, we will attempt to automatically detect the GCE "
"project from metadata."))
flags.DEFINE_integer(name="num_tpu_shards", default=8,
help=help_wrap("Number of shards (TPU chips)."))
flags.DEFINE_integer(
name="num_tpu_shards",
default=8,
help=help_wrap("Number of shards (TPU chips)."))
return key_flags
......@@ -38,7 +38,8 @@ def define_distribution(worker_hosts=True, task_index=True):
if worker_hosts:
flags.DEFINE_string(
name='worker_hosts', default=None,
name='worker_hosts',
default=None,
help=help_wrap(
'Comma-separated list of worker ip:port pairs for running '
'multi-worker models with DistributionStrategy. The user would '
......@@ -47,7 +48,8 @@ def define_distribution(worker_hosts=True, task_index=True):
if task_index:
flags.DEFINE_integer(
name='task_index', default=-1,
name='task_index',
default=-1,
help=help_wrap('If multi-worker training, the task_index of this '
'worker.'))
......
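A sketch of how worker_hosts and task_index are typically consumed together (hypothetical downstream code, not from this diff):

from absl import flags

FLAGS = flags.FLAGS

# e.g. --worker_hosts=10.0.0.1:2222,10.0.0.2:2222 --task_index=1
workers = FLAGS.worker_hosts.split(",")
this_worker = workers[FLAGS.task_index]  # address of the current task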
......@@ -37,7 +37,9 @@ def define_image(data_format=True):
if data_format:
flags.DEFINE_enum(
name="data_format", short_name="df", default=None,
name="data_format",
short_name="df",
default=None,
enum_values=["channels_first", "channels_last"],
help=help_wrap(
"A flag to override the data format used in the model. "
......
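For context, the default-selection pattern this override feeds into (a sketch; a None flag value means the model picks a layout itself):

import tensorflow as tf
from absl import flags

FLAGS = flags.FLAGS

# channels_first (NCHW) is generally faster on CUDA GPUs with cuDNN;
# channels_last (NHWC) is the safe default on CPU.
data_format = FLAGS.data_format or (
    "channels_first" if tf.test.is_built_with_cuda() else "channels_last")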
......@@ -20,12 +20,11 @@ from __future__ import print_function
import multiprocessing
from absl import flags # pylint: disable=g-bad-import-order
import tensorflow as tf # pylint: disable=g-bad-import-order
from absl import flags # pylint: disable=g-bad-import-order
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.flags._conventions import help_wrap
# Map string to TensorFlow dtype
DTYPE_MAP = {
"fp16": tf.float16,
......@@ -55,15 +54,22 @@ def get_loss_scale(flags_obj, default_for_fp16):
return default_for_fp16
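get_loss_scale's body is truncated above; the dispatch implied by the unit tests at the end of this diff (explicit values and "dynamic" pass through, fp16 falls back to default_for_fp16, other dtypes get 1) would look roughly like this sketch, not the verbatim source:

def get_loss_scale(flags_obj, default_for_fp16):
  # Reconstructed from the tests below, not copied from the source.
  if flags_obj.loss_scale == "dynamic":
    return "dynamic"
  elif flags_obj.loss_scale is not None:
    return float(flags_obj.loss_scale)
  elif flags_obj.dtype == "fp32":
    return 1  # Full precision needs no loss scaling.
  else:
    return default_for_fp16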
def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False,
synthetic_data=False, max_train_steps=False, dtype=False,
all_reduce_alg=False, num_packs=False,
def define_performance(num_parallel_calls=False,
inter_op=False,
intra_op=False,
synthetic_data=False,
max_train_steps=False,
dtype=False,
all_reduce_alg=False,
num_packs=False,
tf_gpu_thread_mode=False,
datasets_num_private_threads=False,
datasets_num_parallel_batches=False,
dynamic_loss_scale=False, fp16_implementation=False,
dynamic_loss_scale=False,
fp16_implementation=False,
loss_scale=False,
tf_data_experimental_slack=False, enable_xla=False,
tf_data_experimental_slack=False,
enable_xla=False,
training_dataset_cache=False):
"""Register flags for specifying performance tuning arguments.
......@@ -72,8 +78,8 @@ def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False,
inter_op: Create a flag to allow specification of inter op threads.
intra_op: Create a flag to allow specification of intra op threads.
synthetic_data: Create a flag to allow the use of synthetic data.
    max_train_steps: Create a flag to allow specification of the maximum
      number of training steps.
    max_train_steps: Create a flag to allow specification of the maximum number
      of training steps.
dtype: Create flags for specifying dtype.
all_reduce_alg: If set forces a specific algorithm for multi-gpu.
num_packs: If set provides number of packs for MirroredStrategy's cross
......@@ -81,7 +87,7 @@ def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False,
    tf_gpu_thread_mode: gpu_private triggers use of a private thread pool.
datasets_num_private_threads: Number of private threads for datasets.
datasets_num_parallel_batches: Determines how many batches to process in
parallel when using map and batch from tf.data.
parallel when using map and batch from tf.data.
dynamic_loss_scale: Allow the "loss_scale" flag to take on the value
"dynamic". Only valid if `dtype` is True.
fp16_implementation: Create fp16_implementation flag.
......@@ -91,8 +97,8 @@ def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False,
`experimental_slack` option.
enable_xla: Determines if XLA (auto clustering) is turned on.
training_dataset_cache: Whether to cache the training dataset on workers.
Typically used to improve training performance when training data is in
remote storage and can fit into worker memory.
Typically used to improve training performance when training data is in
remote storage and can fit into worker memory.
Returns:
    A list of flags for core.py to mark as key flags.
......@@ -101,7 +107,8 @@ def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False,
key_flags = []
if num_parallel_calls:
flags.DEFINE_integer(
name="num_parallel_calls", short_name="npc",
name="num_parallel_calls",
short_name="npc",
default=multiprocessing.cpu_count(),
help=help_wrap("The number of records that are processed in parallel "
"during input processing. This can be optimized per "
......@@ -111,20 +118,25 @@ def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False,
if inter_op:
flags.DEFINE_integer(
name="inter_op_parallelism_threads", short_name="inter", default=0,
name="inter_op_parallelism_threads",
short_name="inter",
default=0,
help=help_wrap("Number of inter_op_parallelism_threads to use for CPU. "
"See TensorFlow config.proto for details.")
)
"See TensorFlow config.proto for details."))
if intra_op:
flags.DEFINE_integer(
name="intra_op_parallelism_threads", short_name="intra", default=0,
name="intra_op_parallelism_threads",
short_name="intra",
default=0,
help=help_wrap("Number of intra_op_parallelism_threads to use for CPU. "
"See TensorFlow config.proto for details."))
if synthetic_data:
flags.DEFINE_bool(
name="use_synthetic_data", short_name="synth", default=False,
name="use_synthetic_data",
short_name="synth",
default=False,
help=help_wrap(
"If set, use fake data (zeroes) instead of a real dataset. "
"This mode is useful for performance debugging, as it removes "
......@@ -132,16 +144,20 @@ def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False,
if max_train_steps:
flags.DEFINE_integer(
name="max_train_steps", short_name="mts", default=None, help=help_wrap(
name="max_train_steps",
short_name="mts",
default=None,
help=help_wrap(
"The model will stop training if the global_step reaches this "
"value. If not set, training will run until the specified number "
"of epochs have run as usual. It is generally recommended to set "
"--train_epochs=1 when using this flag."
))
"--train_epochs=1 when using this flag."))
if dtype:
flags.DEFINE_enum(
name="dtype", short_name="dt", default="fp32",
name="dtype",
short_name="dt",
default="fp32",
enum_values=DTYPE_MAP.keys(),
help=help_wrap("The TensorFlow datatype used for calculations. "
"Variables may be cast to a higher precision on a "
......@@ -155,8 +171,7 @@ def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False,
"variables. This is mathematically equivalent to training without "
"a loss scale, but the loss scale helps avoid some intermediate "
"gradients from underflowing to zero. If not provided the default "
"for fp16 is 128 and 1 for all other dtypes.{}"
)
"for fp16 is 128 and 1 for all other dtypes.{}")
if dynamic_loss_scale:
loss_scale_help_text = loss_scale_help_text.format(
"This can be an int/float or the string 'dynamic'",
......@@ -171,11 +186,13 @@ def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False,
loss_scale_validation_msg = "loss_scale should be a positive int/float."
if loss_scale:
flags.DEFINE_string(
name="loss_scale", short_name="ls", default=None,
name="loss_scale",
short_name="ls",
default=None,
help=help_wrap(loss_scale_help_text))
@flags.validator(flag_name="loss_scale",
message=loss_scale_validation_msg)
@flags.validator(
flag_name="loss_scale", message=loss_scale_validation_msg)
def _check_loss_scale(loss_scale): # pylint: disable=unused-variable
"""Validator to check the loss scale flag is valid."""
if loss_scale is None:
......@@ -193,7 +210,8 @@ def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False,
if fp16_implementation:
flags.DEFINE_enum(
name="fp16_implementation", default="keras",
name="fp16_implementation",
default="keras",
enum_values=("keras', 'graph_rewrite"),
help=help_wrap(
"When --dtype=fp16, how fp16 should be implemented. This has no "
......@@ -202,8 +220,8 @@ def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False,
"tf.train.experimental.enable_mixed_precision_graph_rewrite "
"API."))
@flags.multi_flags_validator(["fp16_implementation", "dtype",
"loss_scale"])
@flags.multi_flags_validator(
["fp16_implementation", "dtype", "loss_scale"])
def _check_fp16_implementation(flags_dict):
"""Validator to check fp16_implementation flag is valid."""
if (flags_dict["fp16_implementation"] == "graph_rewrite" and
......@@ -214,7 +232,9 @@ def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False,
if all_reduce_alg:
flags.DEFINE_string(
name="all_reduce_alg", short_name="ara", default=None,
name="all_reduce_alg",
short_name="ara",
default=None,
help=help_wrap("Defines the algorithm to use for performing all-reduce."
"When specified with MirroredStrategy for single "
"worker, this controls "
......@@ -226,24 +246,26 @@ def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False,
if num_packs:
flags.DEFINE_integer(
name="num_packs", default=1,
name="num_packs",
default=1,
help=help_wrap("Sets `num_packs` in the cross device ops used in "
"MirroredStrategy. For details, see "
"tf.distribute.NcclAllReduce."))
if tf_gpu_thread_mode:
flags.DEFINE_string(
name="tf_gpu_thread_mode", short_name="gt_mode", default=None,
name="tf_gpu_thread_mode",
short_name="gt_mode",
default=None,
help=help_wrap(
"Whether and how the GPU device uses its own threadpool.")
)
"Whether and how the GPU device uses its own threadpool."))
flags.DEFINE_integer(
name="per_gpu_thread_count", short_name="pgtc", default=0,
help=help_wrap(
"The number of threads to use for GPU. Only valid when "
"tf_gpu_thread_mode is not global.")
)
name="per_gpu_thread_count",
short_name="pgtc",
default=0,
help=help_wrap("The number of threads to use for GPU. Only valid when "
"tf_gpu_thread_mode is not global."))
if datasets_num_private_threads:
flags.DEFINE_integer(
......@@ -251,8 +273,7 @@ def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False,
default=None,
help=help_wrap(
"Number of threads for a private threadpool created for all"
"datasets computation..")
)
"datasets computation.."))
if datasets_num_parallel_batches:
flags.DEFINE_integer(
......@@ -260,8 +281,7 @@ def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False,
default=None,
help=help_wrap(
"Determines how many batches to process in parallel when using "
"map and batch from tf.data.")
)
"map and batch from tf.data."))
if training_dataset_cache:
flags.DEFINE_boolean(
......@@ -270,20 +290,19 @@ def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False,
help=help_wrap(
"Determines whether to cache the training dataset on workers. "
"Typically used to improve training performance when training "
"data is in remote storage and can fit into worker memory.")
)
"data is in remote storage and can fit into worker memory."))
if tf_data_experimental_slack:
flags.DEFINE_boolean(
name="tf_data_experimental_slack",
default=False,
help=help_wrap(
"Whether to enable tf.data's `experimental_slack` option.")
)
"Whether to enable tf.data's `experimental_slack` option."))
if enable_xla:
flags.DEFINE_boolean(
name="enable_xla", default=False,
name="enable_xla",
default=False,
help="Whether to enable XLA auto jit compilation")
return key_flags
......@@ -22,6 +22,7 @@ from __future__ import division
from __future__ import print_function
import sys
from six.moves import shlex_quote
from absl import app as absl_app
......@@ -65,6 +66,7 @@ def register_key_flags_in_core(f):
def core_fn(*args, **kwargs):
key_flags = f(*args, **kwargs)
[flags.declare_key_flag(fl) for fl in key_flags] # pylint: disable=expression-not-assigned
return core_fn
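In effect the decorator both runs the wrapped definition function and promotes every returned flag name; usage matches the assignments just below (a sketch):

# Each name returned by _performance.define_performance is handed to
# flags.declare_key_flag, so the flag appears in the main module's --help.
define_performance = register_key_flags_in_core(_performance.define_performance)
define_performance(dtype=True, loss_scale=True)  # registers and promotes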
......@@ -80,16 +82,15 @@ define_performance = register_key_flags_in_core(_performance.define_performance)
define_distribution = register_key_flags_in_core(
_distribution.define_distribution)
help_wrap = _conventions.help_wrap
get_num_gpus = _base.get_num_gpus
get_tf_dtype = _performance.get_tf_dtype
get_loss_scale = _performance.get_loss_scale
DTYPE_MAP = _performance.DTYPE_MAP
require_cloud_storage = _device.require_cloud_storage
def _get_nondefault_flags_as_dict():
"""Returns the nondefault flags as a dict from flag name to value."""
nondefault_flags = {}
......
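The body of _get_nondefault_flags_as_dict is truncated here; a plausible completion (hypothetical, inferred from the name and the serialization test near the end of this diff):

from absl import flags

def _get_nondefault_flags_as_dict():
  """Returns the nondefault flags as a dict from flag name to value."""
  nondefault_flags = {}
  # Hypothetical loop: keep only flags whose parsed value differs from the
  # declared default.
  for name in flags.FLAGS.flag_values_dict():
    flag_object = flags.FLAGS[name]
    if flag_object.value != flag_object.default:
      nondefault_flags[name] = flag_object.value
  return nondefault_flags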
......@@ -22,12 +22,20 @@ from official.utils.flags import core as flags_core # pylint: disable=g-bad-imp
def define_flags():
flags_core.define_base(clean=True, num_gpu=False, stop_threshold=True,
hooks=True, train_epochs=True,
epochs_between_evals=True)
flags_core.define_base(
clean=True,
num_gpu=False,
stop_threshold=True,
hooks=True,
train_epochs=True,
epochs_between_evals=True)
flags_core.define_performance(
num_parallel_calls=True, inter_op=True, intra_op=True,
dynamic_loss_scale=True, loss_scale=True, synthetic_data=True,
num_parallel_calls=True,
inter_op=True,
intra_op=True,
dynamic_loss_scale=True,
loss_scale=True,
synthetic_data=True,
dtype=True)
flags_core.define_image()
flags_core.define_benchmark()
......@@ -41,8 +49,7 @@ class BaseTester(unittest.TestCase):
define_flags()
def test_default_setting(self):
"""Test to ensure fields exist and defaults can be set.
"""
"""Test to ensure fields exist and defaults can be set."""
defaults = dict(
data_dir="dfgasf",
......@@ -54,8 +61,7 @@ class BaseTester(unittest.TestCase):
num_parallel_calls=18,
inter_op_parallelism_threads=5,
intra_op_parallelism_threads=10,
data_format="channels_first"
)
data_format="channels_first")
flags_core.set_defaults(**defaults)
flags_core.parse_flags()
......@@ -77,8 +83,7 @@ class BaseTester(unittest.TestCase):
assert flags.FLAGS.get_flag_value(name=key, default=None) == value
def test_booleans(self):
"""Test to ensure boolean flags trigger as expected.
"""
"""Test to ensure boolean flags trigger as expected."""
flags_core.parse_flags([__file__, "--use_synthetic_data"])
......@@ -87,35 +92,33 @@ class BaseTester(unittest.TestCase):
def test_parse_dtype_info(self):
flags_core.parse_flags([__file__, "--dtype", "fp16"])
self.assertEqual(flags_core.get_tf_dtype(flags.FLAGS), tf.float16)
self.assertEqual(flags_core.get_loss_scale(flags.FLAGS,
default_for_fp16=2), 2)
self.assertEqual(
flags_core.get_loss_scale(flags.FLAGS, default_for_fp16=2), 2)
flags_core.parse_flags(
[__file__, "--dtype", "fp16", "--loss_scale", "5"])
self.assertEqual(flags_core.get_loss_scale(flags.FLAGS,
default_for_fp16=2), 5)
flags_core.parse_flags([__file__, "--dtype", "fp16", "--loss_scale", "5"])
self.assertEqual(
flags_core.get_loss_scale(flags.FLAGS, default_for_fp16=2), 5)
flags_core.parse_flags(
[__file__, "--dtype", "fp16", "--loss_scale", "dynamic"])
self.assertEqual(flags_core.get_loss_scale(flags.FLAGS,
default_for_fp16=2), "dynamic")
self.assertEqual(
flags_core.get_loss_scale(flags.FLAGS, default_for_fp16=2), "dynamic")
flags_core.parse_flags([__file__, "--dtype", "fp32"])
self.assertEqual(flags_core.get_tf_dtype(flags.FLAGS), tf.float32)
self.assertEqual(flags_core.get_loss_scale(flags.FLAGS,
default_for_fp16=2), 1)
self.assertEqual(
flags_core.get_loss_scale(flags.FLAGS, default_for_fp16=2), 1)
flags_core.parse_flags([__file__, "--dtype", "fp32", "--loss_scale", "5"])
self.assertEqual(flags_core.get_loss_scale(flags.FLAGS,
default_for_fp16=2), 5)
self.assertEqual(
flags_core.get_loss_scale(flags.FLAGS, default_for_fp16=2), 5)
with self.assertRaises(SystemExit):
flags_core.parse_flags([__file__, "--dtype", "int8"])
with self.assertRaises(SystemExit):
flags_core.parse_flags([__file__, "--dtype", "fp16",
"--loss_scale", "abc"])
flags_core.parse_flags(
[__file__, "--dtype", "fp16", "--loss_scale", "abc"])
def test_get_nondefault_flags_as_str(self):
defaults = dict(
......@@ -123,8 +126,7 @@ class BaseTester(unittest.TestCase):
data_dir="abc",
hooks=["LoggingTensorHook"],
stop_threshold=1.5,
use_synthetic_data=False
)
use_synthetic_data=False)
flags_core.set_defaults(**defaults)
flags_core.parse_flags()
......