Unverified Commit 03b4a0af authored by Hongjun Choi, committed by GitHub

Merged commit includes the following changes: (#7430)

262988559  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Enable NCF TF 2.0 model to run on TPUStrategy.

--
262971756  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Internal change

262967691  by hongkuny<hongkuny@google.com>:

    Internal

--

PiperOrigin-RevId: 262988559
parent 3a14837d
@@ -143,37 +143,32 @@ class DatasetManager(object):
       if is_training:
         return {
             movielens.USER_COLUMN:
-                tf.io.FixedLenFeature([batch_size], dtype=tf.int64),
+                tf.io.FixedLenFeature([batch_size, 1], dtype=tf.int64),
             movielens.ITEM_COLUMN:
-                tf.io.FixedLenFeature([batch_size], dtype=tf.int64),
+                tf.io.FixedLenFeature([batch_size, 1], dtype=tf.int64),
             rconst.VALID_POINT_MASK:
-                tf.io.FixedLenFeature([batch_size], dtype=tf.int64),
+                tf.io.FixedLenFeature([batch_size, 1], dtype=tf.int64),
             "labels":
-                tf.io.FixedLenFeature([batch_size], dtype=tf.int64)
+                tf.io.FixedLenFeature([batch_size, 1], dtype=tf.int64)
         }
       else:
         return {
             movielens.USER_COLUMN:
-                tf.io.FixedLenFeature([batch_size], dtype=tf.int64),
+                tf.io.FixedLenFeature([batch_size, 1], dtype=tf.int64),
             movielens.ITEM_COLUMN:
-                tf.io.FixedLenFeature([batch_size], dtype=tf.int64),
+                tf.io.FixedLenFeature([batch_size, 1], dtype=tf.int64),
             rconst.DUPLICATE_MASK:
-                tf.io.FixedLenFeature([batch_size], dtype=tf.int64)
+                tf.io.FixedLenFeature([batch_size, 1], dtype=tf.int64)
         }

     features = tf.io.parse_single_example(
         serialized_data, _get_feature_map(batch_size, is_training=is_training))
-    users = tf.reshape(
-        tf.cast(features[movielens.USER_COLUMN], rconst.USER_DTYPE),
-        (batch_size,))
-    items = tf.reshape(
-        tf.cast(features[movielens.ITEM_COLUMN], rconst.ITEM_DTYPE),
-        (batch_size,))
+    users = tf.cast(features[movielens.USER_COLUMN], rconst.USER_DTYPE)
+    items = tf.cast(features[movielens.ITEM_COLUMN], rconst.ITEM_DTYPE)

     if is_training:
-      valid_point_mask = tf.reshape(
-          tf.cast(features[movielens.ITEM_COLUMN], tf.bool), (batch_size,))
-      fake_dup_mask = tf.zeros_like(features[movielens.USER_COLUMN])
+      valid_point_mask = tf.cast(features[rconst.VALID_POINT_MASK], tf.bool)
+      fake_dup_mask = tf.zeros_like(users)
       return {
           movielens.USER_COLUMN: users,
           movielens.ITEM_COLUMN: items,
@@ -184,20 +179,15 @@ class DatasetManager(object):
           rconst.DUPLICATE_MASK: fake_dup_mask
       }
     else:
-      labels = tf.reshape(
-          tf.cast(tf.zeros_like(features[movielens.USER_COLUMN]), tf.bool),
-          (batch_size, 1))
-      fake_valid_pt_mask = tf.cast(
-          tf.zeros_like(features[movielens.USER_COLUMN]), tf.bool)
+      labels = tf.cast(tf.zeros_like(users), tf.bool)
+      fake_valid_pt_mask = tf.cast(tf.zeros_like(users), tf.bool)
       return {
           movielens.USER_COLUMN:
               users,
           movielens.ITEM_COLUMN:
               items,
           rconst.DUPLICATE_MASK:
-              tf.reshape(
-                  tf.cast(features[rconst.DUPLICATE_MASK], tf.bool),
-                  (batch_size,)),
+              tf.cast(features[rconst.DUPLICATE_MASK], tf.bool),
           rconst.VALID_POINT_MASK:
               fake_valid_pt_mask,
           rconst.TRAIN_LABEL_KEY:
@@ -221,8 +211,8 @@ class DatasetManager(object):
     if self._is_training:
       mask_start_index = data.pop(rconst.MASK_START_INDEX)
       batch_size = data[movielens.ITEM_COLUMN].shape[0]
-      data[rconst.VALID_POINT_MASK] = np.less(
-          np.arange(batch_size), mask_start_index)
+      data[rconst.VALID_POINT_MASK] = np.expand_dims(
+          np.less(np.arange(batch_size), mask_start_index), -1)

     if self._stream_files:
       example_bytes = self.serialize(data)
@@ -313,19 +303,21 @@ class DatasetManager(object):
     else:
       types = {movielens.USER_COLUMN: rconst.USER_DTYPE,
                movielens.ITEM_COLUMN: rconst.ITEM_DTYPE}
-      shapes = {movielens.USER_COLUMN: tf.TensorShape([batch_size]),
-                movielens.ITEM_COLUMN: tf.TensorShape([batch_size])}
+      shapes = {
+          movielens.USER_COLUMN: tf.TensorShape([batch_size, 1]),
+          movielens.ITEM_COLUMN: tf.TensorShape([batch_size, 1])
+      }

       if self._is_training:
         types[rconst.VALID_POINT_MASK] = np.bool
-        shapes[rconst.VALID_POINT_MASK] = tf.TensorShape([batch_size])
+        shapes[rconst.VALID_POINT_MASK] = tf.TensorShape([batch_size, 1])
         types = (types, np.bool)
-        shapes = (shapes, tf.TensorShape([batch_size]))
+        shapes = (shapes, tf.TensorShape([batch_size, 1]))
       else:
         types[rconst.DUPLICATE_MASK] = np.bool
-        shapes[rconst.DUPLICATE_MASK] = tf.TensorShape([batch_size])
+        shapes[rconst.DUPLICATE_MASK] = tf.TensorShape([batch_size, 1])

     data_generator = functools.partial(
         self.data_generator, epochs_between_evals=epochs_between_evals)
@@ -554,12 +546,17 @@ class BaseDataConstructor(threading.Thread):
     items = np.concatenate([items, item_pad])
     labels = np.concatenate([labels, label_pad])

-    self._train_dataset.put(i, {
-        movielens.USER_COLUMN: users,
-        movielens.ITEM_COLUMN: items,
-        rconst.MASK_START_INDEX: np.array(mask_start_index, dtype=np.int32),
-        "labels": labels,
-    })
+    self._train_dataset.put(
+        i, {
+            movielens.USER_COLUMN:
+                np.reshape(users, (self.train_batch_size, 1)),
+            movielens.ITEM_COLUMN:
+                np.reshape(items, (self.train_batch_size, 1)),
+            rconst.MASK_START_INDEX:
+                np.array(mask_start_index, dtype=np.int32),
+            "labels":
+                np.reshape(labels, (self.train_batch_size, 1)),
+        })

   def _wait_to_construct_train_epoch(self):
     count = 0
@@ -649,11 +646,15 @@ class BaseDataConstructor(threading.Thread):
     users, items, duplicate_mask = self._assemble_eval_batch(
         users, positive_items, negative_items, self._eval_users_per_batch)

-    self._eval_dataset.put(i, {
-        movielens.USER_COLUMN: users.flatten(),
-        movielens.ITEM_COLUMN: items.flatten(),
-        rconst.DUPLICATE_MASK: duplicate_mask.flatten(),
-    })
+    self._eval_dataset.put(
+        i, {
+            movielens.USER_COLUMN:
+                np.reshape(users.flatten(), (self.eval_batch_size, 1)),
+            movielens.ITEM_COLUMN:
+                np.reshape(items.flatten(), (self.eval_batch_size, 1)),
+            rconst.DUPLICATE_MASK:
+                np.reshape(duplicate_mask.flatten(), (self.eval_batch_size, 1)),
+        })

   def _construct_eval_epoch(self):
     """Loop to construct data for evaluation."""
@@ -720,24 +721,37 @@ class DummyConstructor(threading.Thread):
       num_users = params["num_users"]
       num_items = params["num_items"]

-      users = tf.random.uniform([batch_size], dtype=tf.int32, minval=0,
-                                maxval=num_users)
-      items = tf.random.uniform([batch_size], dtype=tf.int32, minval=0,
-                                maxval=num_items)
+      users = tf.random.uniform([batch_size, 1],
+                                dtype=tf.int32,
+                                minval=0,
+                                maxval=num_users)
+      items = tf.random.uniform([batch_size, 1],
+                                dtype=tf.int32,
+                                minval=0,
+                                maxval=num_items)

       if is_training:
-        valid_point_mask = tf.cast(tf.random.uniform(
-            [batch_size], dtype=tf.int32, minval=0, maxval=2), tf.bool)
-        labels = tf.cast(tf.random.uniform(
-            [batch_size], dtype=tf.int32, minval=0, maxval=2), tf.bool)
+        valid_point_mask = tf.cast(
+            tf.random.uniform([batch_size, 1],
+                              dtype=tf.int32,
+                              minval=0,
+                              maxval=2), tf.bool)
+        labels = tf.cast(
+            tf.random.uniform([batch_size, 1],
+                              dtype=tf.int32,
+                              minval=0,
+                              maxval=2), tf.bool)
         data = {
             movielens.USER_COLUMN: users,
             movielens.ITEM_COLUMN: items,
            rconst.VALID_POINT_MASK: valid_point_mask,
         }, labels
       else:
-        dupe_mask = tf.cast(tf.random.uniform([batch_size], dtype=tf.int32,
-                                              minval=0, maxval=2), tf.bool)
+        dupe_mask = tf.cast(
+            tf.random.uniform([batch_size, 1],
+                              dtype=tf.int32,
+                              minval=0,
+                              maxval=2), tf.bool)
         data = {
             movielens.USER_COLUMN: users,
             movielens.ITEM_COLUMN: items,
...
@@ -168,8 +168,11 @@ class BaseTest(tf.test.TestCase):
       md5 = hashlib.md5()
       for features, labels in first_epoch:
         data_list = [
-            features[movielens.USER_COLUMN], features[movielens.ITEM_COLUMN],
-            features[rconst.VALID_POINT_MASK], labels]
+            features[movielens.USER_COLUMN].flatten(),
+            features[movielens.ITEM_COLUMN].flatten(),
+            features[rconst.VALID_POINT_MASK].flatten(),
+            labels.flatten()
+        ]
         for i in data_list:
           md5.update(i.tobytes())
@@ -216,8 +219,10 @@ class BaseTest(tf.test.TestCase):
       md5 = hashlib.md5()
       for features in eval_data:
         data_list = [
-            features[movielens.USER_COLUMN], features[movielens.ITEM_COLUMN],
-            features[rconst.DUPLICATE_MASK]]
+            features[movielens.USER_COLUMN].flatten(),
+            features[movielens.ITEM_COLUMN].flatten(),
+            features[rconst.DUPLICATE_MASK].flatten()
+        ]
         for i in data_list:
           md5.update(i.tobytes())
@@ -276,8 +281,11 @@ class BaseTest(tf.test.TestCase):
       md5 = hashlib.md5()
       for features, labels in results:
         data_list = [
-            features[movielens.USER_COLUMN], features[movielens.ITEM_COLUMN],
-            features[rconst.VALID_POINT_MASK], labels]
+            features[movielens.USER_COLUMN].flatten(),
+            features[movielens.ITEM_COLUMN].flatten(),
+            features[rconst.VALID_POINT_MASK].flatten(),
+            labels.flatten()
+        ]
         for i in data_list:
           md5.update(i.tobytes())
...
@@ -37,7 +37,6 @@ from official.utils.flags import core as flags_core
 from official.utils.misc import distribution_utils
 from official.utils.misc import keras_utils

 FLAGS = flags.FLAGS
@@ -60,13 +59,8 @@ def get_inputs(params):
       dataset=FLAGS.dataset, data_dir=FLAGS.data_dir, params=params,
       constructor_type=FLAGS.constructor_type,
       deterministic=FLAGS.seed is not None)

-  num_train_steps = (producer.train_batches_per_epoch //
-                     params["batches_per_step"])
-  num_eval_steps = (producer.eval_batches_per_epoch //
-                    params["batches_per_step"])
-  assert not producer.train_batches_per_epoch % params["batches_per_step"]
-  assert not producer.eval_batches_per_epoch % params["batches_per_step"]
+  num_train_steps = producer.train_batches_per_epoch
+  num_eval_steps = producer.eval_batches_per_epoch

   return num_users, num_items, num_train_steps, num_eval_steps, producer
@@ -74,18 +68,13 @@ def get_inputs(params):
 def parse_flags(flags_obj):
   """Convenience function to turn flags into params."""
   num_gpus = flags_core.get_num_gpus(flags_obj)
-  num_devices = FLAGS.num_tpu_shards if FLAGS.tpu else num_gpus or 1
-  batch_size = (flags_obj.batch_size + num_devices - 1) // num_devices

-  eval_divisor = (rconst.NUM_EVAL_NEGATIVES + 1) * num_devices
+  batch_size = flags_obj.batch_size
   eval_batch_size = flags_obj.eval_batch_size or flags_obj.batch_size
-  eval_batch_size = ((eval_batch_size + eval_divisor - 1) //
-                     eval_divisor * eval_divisor // num_devices)

   return {
       "train_epochs": flags_obj.train_epochs,
-      "batches_per_step": num_devices,
+      "batches_per_step": 1,
       "use_seed": flags_obj.seed is not None,
       "batch_size": batch_size,
       "eval_batch_size": eval_batch_size,
@@ -95,6 +84,7 @@ def parse_flags(flags_obj):
       "mf_regularization": flags_obj.mf_regularization,
       "mlp_reg_layers": [float(reg) for reg in flags_obj.mlp_regularization],
       "num_neg": flags_obj.num_neg,
+      "distribution_strategy": flags_obj.distribution_strategy,
       "num_gpus": num_gpus,
       "use_tpu": flags_obj.tpu is not None,
       "tpu": flags_obj.tpu,
@@ -115,7 +105,7 @@ def parse_flags(flags_obj):
   }


-def get_distribution_strategy(params):
+def get_v1_distribution_strategy(params):
   """Returns the distribution strategy to use."""
   if params["use_tpu"]:
     # Some of the networking libraries are quite chatty.
...
@@ -66,7 +66,7 @@ def construct_estimator(model_dir, params):
   Returns:
     An Estimator or TPUEstimator.
   """
-  distribution = ncf_common.get_distribution_strategy(params)
+  distribution = ncf_common.get_v1_distribution_strategy(params)
   run_config = tf.estimator.RunConfig(train_distribute=distribution,
                                       eval_distribute=distribution)
...
@@ -82,7 +82,6 @@ def create_dataset_from_data_producer(producer, params):
     Returns:
       Processed training features.
     """
-    labels = tf.expand_dims(labels, -1)
     fake_dup_mask = tf.zeros_like(features[movielens.USER_COLUMN])
     features[rconst.DUPLICATE_MASK] = fake_dup_mask
     features[rconst.TRAIN_LABEL_KEY] = labels
@@ -106,7 +105,6 @@ def create_dataset_from_data_producer(producer, params):
       Processed evaluation features.
     """
     labels = tf.cast(tf.zeros_like(features[movielens.USER_COLUMN]), tf.bool)
-    labels = tf.expand_dims(labels, -1)
     fake_valid_pt_mask = tf.cast(
         tf.zeros_like(features[movielens.USER_COLUMN]), tf.bool)
     features[rconst.VALID_POINT_MASK] = fake_valid_pt_mask
@@ -134,9 +132,13 @@ def create_ncf_input_data(params, producer=None, input_meta_data=None):
   Returns:
     (training dataset, evaluation dataset, train steps per epoch,
     eval steps per epoch)
+
+  Raises:
+    ValueError: If data is being generated online when using TPUs.
   """
   if params["train_dataset_path"]:
+    assert params["eval_dataset_path"]
+
     train_dataset = create_dataset_from_tf_record_files(
         params["train_dataset_path"],
         input_meta_data["train_prebatch_size"],
@@ -148,34 +150,18 @@ def create_ncf_input_data(params, producer=None, input_meta_data=None):
         params["eval_batch_size"],
         is_training=False)
-    # TODO(b/259377621): Remove number of devices (i.e.
-    # params["batches_per_step"]) in input pipeline logic and only use
-    # global batch size instead.
-    num_train_steps = int(
-        np.ceil(input_meta_data["num_train_steps"] /
-                params["batches_per_step"]))
-    num_eval_steps = (
-        input_meta_data["num_eval_steps"] // params["batches_per_step"])
+    num_train_steps = int(input_meta_data["num_train_steps"])
+    num_eval_steps = int(input_meta_data["num_eval_steps"])
   else:
+    if params["use_tpu"]:
+      raise ValueError("TPU training does not support data producer yet. "
+                       "Use pre-processed data.")
+
     assert producer
     # Start retrieving data from producer.
     train_dataset, eval_dataset = create_dataset_from_data_producer(
         producer, params)
-    num_train_steps = (
-        producer.train_batches_per_epoch // params["batches_per_step"])
-    num_eval_steps = (
-        producer.eval_batches_per_epoch // params["batches_per_step"])
-    assert not producer.train_batches_per_epoch % params["batches_per_step"]
-    assert not producer.eval_batches_per_epoch % params["batches_per_step"]

-  # It is required that for distributed training, the dataset must call
-  # batch(). The parameter of batch() here is the number of replicas involed,
-  # such that each replica evenly gets a slice of data.
-  # drop_remainder = True, as we would like batch call to return a fixed shape
-  # vs None, this prevents a expensive broadcast during weighted_loss
-  batches_per_step = params["batches_per_step"]
-  train_dataset = train_dataset.batch(batches_per_step, drop_remainder=True)
-  eval_dataset = eval_dataset.batch(batches_per_step, drop_remainder=True)

   return train_dataset, eval_dataset, num_train_steps, num_eval_steps
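
With the per-device batching removed above, a dataset element is now a single global batch: a dict of [batch_size, 1] feature tensors rather than a [batches_per_step, batch_size] stack. A minimal shape check under that assumption (the feature names here are placeholders for the movielens column constants):

    import tensorflow as tf

    batch_size = 4
    features = {
        "user_id": tf.zeros([batch_size, 1], tf.int32),  # placeholder name
        "item_id": tf.zeros([batch_size, 1], tf.int32),  # placeholder name
    }
    dataset = tf.data.Dataset.from_tensors(features)
    element = next(iter(dataset))
    assert element["user_id"].shape == (batch_size, 1)
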
@@ -42,14 +42,14 @@ from official.utils.logs import mlperf_helper
 from official.utils.misc import distribution_utils
 from official.utils.misc import keras_utils
 from official.utils.misc import model_helpers
+from official.utils.misc import tpu_lib

 FLAGS = flags.FLAGS


 def metric_fn(logits, dup_mask, params):
   dup_mask = tf.cast(dup_mask, tf.float32)
-  logits = tf.slice(logits, [0, 0, 1], [-1, -1, -1])
+  logits = tf.slice(logits, [0, 1], [-1, -1])
   in_top_k, _, metric_weights, _ = neumf_model.compute_top_k_and_ndcg(
       logits,
       dup_mask,
@@ -73,6 +73,24 @@ class MetricLayer(tf.keras.layers.Layer):
     return logits


+class LossLayer(tf.keras.layers.Layer):
+  """Pass-through loss layer for NCF model."""
+
+  def __init__(self, loss_normalization_factor):
+    super(LossLayer, self).__init__()
+    self.loss_normalization_factor = loss_normalization_factor
+    self.loss = tf.keras.losses.SparseCategoricalCrossentropy(
+        from_logits=True, reduction="sum")
+
+  def call(self, inputs):
+    logits, labels, valid_pt_mask_input = inputs
+    loss = self.loss(
+        y_true=labels, y_pred=logits, sample_weight=valid_pt_mask_input)
+    loss = loss * (1.0 / self.loss_normalization_factor)
+    self.add_loss(loss)
+    return logits
+
+
 class IncrementEpochCallback(tf.keras.callbacks.Callback):
   """A callback to increase the requested epoch for the data producer.
@@ -122,48 +140,24 @@ def _get_keras_model(params):
   """Constructs and returns the model."""
   batch_size = params["batch_size"]

-  # The input layers are of shape (1, batch_size), to match the size of the
-  # input data. The first dimension is needed because the input data are
-  # required to be batched to use distribution strategies, and in this case, it
-  # is designed to be of batch_size 1 for each replica.
   user_input = tf.keras.layers.Input(
-      shape=(batch_size,),
-      batch_size=params["batches_per_step"],
-      name=movielens.USER_COLUMN,
-      dtype=tf.int32)
+      shape=(1,), name=movielens.USER_COLUMN, dtype=tf.int32)

   item_input = tf.keras.layers.Input(
-      shape=(batch_size,),
-      batch_size=params["batches_per_step"],
-      name=movielens.ITEM_COLUMN,
-      dtype=tf.int32)
+      shape=(1,), name=movielens.ITEM_COLUMN, dtype=tf.int32)

   valid_pt_mask_input = tf.keras.layers.Input(
-      shape=(batch_size,),
-      batch_size=params["batches_per_step"],
-      name=rconst.VALID_POINT_MASK,
-      dtype=tf.bool)
+      shape=(1,), name=rconst.VALID_POINT_MASK, dtype=tf.bool)

   dup_mask_input = tf.keras.layers.Input(
-      shape=(batch_size,),
-      batch_size=params["batches_per_step"],
-      name=rconst.DUPLICATE_MASK,
-      dtype=tf.int32)
+      shape=(1,), name=rconst.DUPLICATE_MASK, dtype=tf.int32)

   label_input = tf.keras.layers.Input(
-      shape=(batch_size, 1),
-      batch_size=params["batches_per_step"],
-      name=rconst.TRAIN_LABEL_KEY,
-      dtype=tf.bool)
+      shape=(1,), name=rconst.TRAIN_LABEL_KEY, dtype=tf.bool)

-  base_model = neumf_model.construct_model(
-      user_input, item_input, params, need_strip=True)
-
-  base_model_output = base_model.output
-  logits = tf.keras.layers.Lambda(
-      lambda x: tf.expand_dims(x, 0),
-      name="logits")(base_model_output)
+  base_model = neumf_model.construct_model(user_input, item_input, params)
+  logits = base_model.output

   zeros = tf.keras.layers.Lambda(
       lambda x: x * 0)(logits)
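
With features shaped [batch_size, 1], the inputs above follow the standard Keras convention where the batch dimension is implicit in Input(shape=(1,)). A minimal sketch of that correspondence (sizes are illustrative):

    import tensorflow as tf

    user_input = tf.keras.layers.Input(shape=(1,), dtype=tf.int32)
    embedding = tf.keras.layers.Embedding(10, 4, input_length=1)(user_input)
    model = tf.keras.Model(user_input, embedding)

    user_ids = tf.constant([[3], [1], [4]], dtype=tf.int32)  # shape (3, 1)
    print(model(user_ids).shape)  # (3, 1, 4): batch, input_length, embedding
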
@@ -172,9 +166,14 @@ def _get_keras_model(params):
       [zeros, logits],
       axis=-1)

-  """CTL does metric calculation as part of eval_step function"""
+  # Custom training loop calculates loss and metric as a part of
+  # training/evaluation step function.
   if not params["keras_use_ctl"]:
     softmax_logits = MetricLayer(params)([softmax_logits, dup_mask_input])
+    # TODO(b/134744680): Use model.add_loss() instead once the API is well
+    # supported.
+    softmax_logits = LossLayer(batch_size)(
+        [softmax_logits, label_input, valid_pt_mask_input])

   keras_model = tf.keras.Model(
       inputs={
@@ -185,15 +184,6 @@ def _get_keras_model(params):
           rconst.TRAIN_LABEL_KEY: label_input},
       outputs=softmax_logits)

-  loss_obj = tf.keras.losses.SparseCategoricalCrossentropy(
-      from_logits=True,
-      reduction="sum")
-  keras_model.add_loss(loss_obj(
-      y_true=label_input,
-      y_pred=softmax_logits,
-      sample_weight=valid_pt_mask_input) * 1.0 / batch_size)
-
   keras_model.summary()
   return keras_model
@@ -207,39 +197,28 @@ def run_ncf(_):
     print("Setting tf seed")
     tf.random.set_seed(FLAGS.seed)

-  # TODO(seemuch): Support different train and eval batch sizes
-  if FLAGS.eval_batch_size != FLAGS.batch_size:
-    logging.warning(
-        "The Keras implementation of NCF currently does not support batch_size "
-        "!= eval_batch_size ({} vs. {}). Overriding eval_batch_size to match "
-        "batch_size".format(FLAGS.eval_batch_size, FLAGS.batch_size)
-        )
-    FLAGS.eval_batch_size = FLAGS.batch_size
-
   params = ncf_common.parse_flags(FLAGS)
   model_helpers.apply_clean(flags.FLAGS)

   strategy = distribution_utils.get_distribution_strategy(
       distribution_strategy=FLAGS.distribution_strategy,
-      num_gpus=FLAGS.num_gpus)
+      num_gpus=FLAGS.num_gpus,
+      tpu_address=FLAGS.tpu)
   params["distribute_strategy"] = strategy

   if not keras_utils.is_v2_0() and strategy is not None:
     logging.error("NCF Keras only works with distribution strategy in TF 2.0")
     return
   if (params["keras_use_ctl"] and (
       not keras_utils.is_v2_0() or strategy is None)):
     logging.error(
         "Custom training loop only works with tensorflow 2.0 and dist strat.")
     return
+  if params["use_tpu"] and not params["keras_use_ctl"]:
+    logging.error("Custom training loop must be used when using TPUStrategy.")
+    return

-  # ncf_common rounds eval_batch_size (this is needed due to a reshape during
-  # eval). This carries over that rounding to batch_size as well. This is the
-  # per device batch size
-  params["batch_size"] = params["eval_batch_size"]
   batch_size = params["batch_size"]

   time_callback = keras_utils.TimeHistory(batch_size, FLAGS.log_steps)
   callbacks = [time_callback]
@@ -248,8 +227,7 @@ def run_ncf(_):
   if generate_input_online:
     # Start data producing thread.
-    num_users, num_items, num_train_steps, num_eval_steps, producer = (
-        ncf_common.get_inputs(params))
+    num_users, num_items, _, _, producer = ncf_common.get_inputs(params)
     producer.start()
     per_epoch_callback = IncrementEpochCallback(producer)
     callbacks.append(per_epoch_callback)
@@ -261,150 +239,213 @@ def run_ncf(_):
     num_items = input_meta_data["num_items"]

   params["num_users"], params["num_items"] = num_users, num_items
-  (train_input_dataset, eval_input_dataset, num_train_steps, num_eval_steps) = \
-    (ncf_input_pipeline.create_ncf_input_data(
-        params, producer, input_meta_data))
-  steps_per_epoch = None if generate_input_online else num_train_steps

   if FLAGS.early_stopping:
     early_stopping_callback = CustomEarlyStopping(
         "val_HR_METRIC", desired_value=FLAGS.hr_threshold)
     callbacks.append(early_stopping_callback)

-  with distribution_utils.get_strategy_scope(strategy):
-    keras_model = _get_keras_model(params)
-    optimizer = tf.keras.optimizers.Adam(
-        learning_rate=params["learning_rate"],
-        beta_1=params["beta1"],
-        beta_2=params["beta2"],
-        epsilon=params["epsilon"])
-
-  if params["keras_use_ctl"]:
-    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
-        reduction="sum",
-        from_logits=True)
-    train_input_iterator = strategy.make_dataset_iterator(train_input_dataset)
-    eval_input_iterator = strategy.make_dataset_iterator(eval_input_dataset)
-
-    def train_step():
-      """Called once per step to train the model."""
-      def step_fn(features):
-        """Computes loss and applied gradient per replica."""
-        with tf.GradientTape() as tape:
-          softmax_logits = keras_model(features)
-          labels = features[rconst.TRAIN_LABEL_KEY]
-          loss = loss_object(labels, softmax_logits,
-                             sample_weight=features[rconst.VALID_POINT_MASK])
-          loss *= (1.0 / (batch_size*strategy.num_replicas_in_sync))
-        grads = tape.gradient(loss, keras_model.trainable_variables)
-        # Converting gradients to dense form helps in perf on GPU for NCF
-        grads = neumf_model.sparse_to_dense_grads(
-            list(zip(grads, keras_model.trainable_variables)))
-        optimizer.apply_gradients(grads)
-        return loss
-
-      per_replica_losses = strategy.experimental_run(step_fn,
-                                                     train_input_iterator)
-      mean_loss = strategy.reduce(
-          tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
-      return mean_loss
-
-    def eval_step():
-      """Called once per eval step to compute eval metrics."""
-      def step_fn(features):
-        """Computes eval metrics per replica."""
-        softmax_logits = keras_model(features)
-        in_top_k, metric_weights = metric_fn(
-            softmax_logits, features[rconst.DUPLICATE_MASK], params)
-        hr_sum = tf.reduce_sum(in_top_k*metric_weights)
-        hr_count = tf.reduce_sum(metric_weights)
-        return hr_sum, hr_count
-
-      per_replica_hr_sum, per_replica_hr_count = (
-          strategy.experimental_run(step_fn, eval_input_iterator))
-      hr_sum = strategy.reduce(
-          tf.distribute.ReduceOp.SUM, per_replica_hr_sum, axis=None)
-      hr_count = strategy.reduce(
-          tf.distribute.ReduceOp.SUM, per_replica_hr_count, axis=None)
-      return hr_sum, hr_count
-
-    if not FLAGS.run_eagerly:
-      train_step = tf.function(train_step)
-      eval_step = tf.function(eval_step)
-
-    time_callback.on_train_begin()
-    for epoch in range(FLAGS.train_epochs):
-      for cb in callbacks:
-        cb.on_epoch_begin(epoch)
-      # As NCF dataset is sampled with randomness, not repeating
-      # data elements in each epoch has significant impact on
-      # convergence. As so, offline-generated TF record files
-      # contains all epoch worth of data. Thus we do not need
-      # to initialize dataset when reading from tf record files.
-      if generate_input_online:
-        train_input_iterator.initialize()
-      train_loss = 0
-      for step in range(num_train_steps):
-        time_callback.on_batch_begin(step+epoch*num_train_steps)
-        train_loss += train_step()
-        time_callback.on_batch_end(step+epoch*num_train_steps)
-      train_loss /= num_train_steps
-      logging.info("Done training epoch %s, epoch loss=%s.",
-                   epoch+1, train_loss)
-      eval_input_iterator.initialize()
-      hr_sum = 0
-      hr_count = 0
-      for _ in range(num_eval_steps):
-        step_hr_sum, step_hr_count = eval_step()
-        hr_sum += step_hr_sum
-        hr_count += step_hr_count
-      logging.info("Done eval epoch %s, hr=%s.", epoch+1, hr_sum/hr_count)
-      if (FLAGS.early_stopping and
-          float(hr_sum/hr_count) > params["hr_threshold"]):
-        break
-    time_callback.on_train_end()
-    eval_results = [None, hr_sum/hr_count]
-  else:
-    with distribution_utils.get_strategy_scope(strategy):
-      # TODO(b/138957587): Remove when force_v2_in_keras_compile is on longer
-      # a valid arg for this model. Also remove as a valid flag.
-      if FLAGS.force_v2_in_keras_compile is not None:
-        keras_model.compile(
-            optimizer=optimizer,
-            run_eagerly=FLAGS.run_eagerly,
-            experimental_run_tf_function=FLAGS.force_v2_in_keras_compile)
-      else:
-        keras_model.compile(
-            optimizer=optimizer,
-            run_eagerly=FLAGS.run_eagerly)
-
-    history = keras_model.fit(
-        train_input_dataset,
-        epochs=FLAGS.train_epochs,
-        steps_per_epoch=steps_per_epoch,
-        callbacks=callbacks,
-        validation_data=eval_input_dataset,
-        validation_steps=num_eval_steps,
-        verbose=2)
-    logging.info("Training done. Start evaluating")
-
-    eval_results = keras_model.evaluate(
-        eval_input_dataset, steps=num_eval_steps, verbose=2)
-
-    logging.info("Keras evaluation is done.")
-
-  if history and history.history:
-    train_history = history.history
-    train_loss = train_history["loss"][-1]
-  stats = build_stats(train_loss, eval_results, time_callback)
-  return stats
+  with tf.device(tpu_lib.get_primary_cpu_task(params["use_tpu"])):
+    (train_input_dataset, eval_input_dataset,
+     num_train_steps, num_eval_steps) = \
+        (ncf_input_pipeline.create_ncf_input_data(
+            params, producer, input_meta_data))
+    steps_per_epoch = None if generate_input_online else num_train_steps
+
+    with distribution_utils.get_strategy_scope(strategy):
+      keras_model = _get_keras_model(params)
+      optimizer = tf.keras.optimizers.Adam(
+          learning_rate=params["learning_rate"],
+          beta_1=params["beta1"],
+          beta_2=params["beta2"],
+          epsilon=params["epsilon"])
+
+    if params["keras_use_ctl"]:
+      train_loss, eval_results = run_ncf_custom_training(
+          params,
+          strategy,
+          keras_model,
+          optimizer,
+          callbacks,
+          train_input_dataset,
+          eval_input_dataset,
+          num_train_steps,
+          num_eval_steps,
+          generate_input_online=generate_input_online)
+    else:
+      # TODO(b/138957587): Remove when force_v2_in_keras_compile is on longer
+      # a valid arg for this model. Also remove as a valid flag.
+      if FLAGS.force_v2_in_keras_compile is not None:
+        keras_model.compile(
+            optimizer=optimizer,
+            run_eagerly=FLAGS.run_eagerly,
+            experimental_run_tf_function=FLAGS.force_v2_in_keras_compile)
+      else:
+        keras_model.compile(
+            optimizer=optimizer, run_eagerly=FLAGS.run_eagerly)
+
+      history = keras_model.fit(
+          train_input_dataset,
+          epochs=FLAGS.train_epochs,
+          steps_per_epoch=steps_per_epoch,
+          callbacks=callbacks,
+          validation_data=eval_input_dataset,
+          validation_steps=num_eval_steps,
+          verbose=2)
+      logging.info("Training done. Start evaluating")
+
+      eval_results = keras_model.evaluate(
+          eval_input_dataset, steps=num_eval_steps, verbose=2)
+      logging.info("Keras evaluation is done.")
+
+      if history and history.history:
+        train_history = history.history
+        train_loss = train_history["loss"][-1]
+
+  stats = build_stats(train_loss, eval_results, time_callback)
+  return stats
+
+
+def run_ncf_custom_training(params,
+                            strategy,
+                            keras_model,
+                            optimizer,
+                            callbacks,
+                            train_input_dataset,
+                            eval_input_dataset,
+                            num_train_steps,
+                            num_eval_steps,
+                            generate_input_online=True):
+  """Runs custom training loop.
+
+  Args:
+    params: Dictionary containing training parameters.
+    strategy: Distribution strategy to be used for distributed training.
+    keras_model: Model used for training.
+    optimizer: Optimizer used for training.
+    callbacks: Callbacks to be invoked between batches/epochs.
+    train_input_dataset: tf.data.Dataset used for training.
+    eval_input_dataset: tf.data.Dataset used for evaluation.
+    num_train_steps: Total number of steps to run for training.
+    num_eval_steps: Total number of steps to run for evaluation.
+    generate_input_online: Whether input data was generated by data producer.
+      When data is generated by data producer, then train dataset must be
+      re-initialized after every epoch.
+
+  Returns:
+    A tuple of train loss and a list of training and evaluation results.
+  """
+  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
+      reduction="sum", from_logits=True)
+  train_input_iterator = iter(
+      strategy.experimental_distribute_dataset(train_input_dataset))
+
+  def train_step(train_iterator):
+    """Called once per step to train the model."""
+
+    def step_fn(features):
+      """Computes loss and applied gradient per replica."""
+      with tf.GradientTape() as tape:
+        softmax_logits = keras_model(features)
+        labels = features[rconst.TRAIN_LABEL_KEY]
+        loss = loss_object(
+            labels,
+            softmax_logits,
+            sample_weight=features[rconst.VALID_POINT_MASK])
+        loss *= (1.0 / params["batch_size"])
+      grads = tape.gradient(loss, keras_model.trainable_variables)
+      # Converting gradients to dense form helps in perf on GPU for NCF
+      grads = neumf_model.sparse_to_dense_grads(
+          list(zip(grads, keras_model.trainable_variables)))
+      optimizer.apply_gradients(grads)
+      return loss
+
+    per_replica_losses = strategy.experimental_run_v2(
+        step_fn, args=(next(train_iterator),))
+    mean_loss = strategy.reduce(
+        tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
+    return mean_loss
+
+  def eval_step(eval_iterator):
+    """Called once per eval step to compute eval metrics."""

+    def step_fn(features):
+      """Computes eval metrics per replica."""
+      softmax_logits = keras_model(features)
+      in_top_k, metric_weights = metric_fn(softmax_logits,
+                                           features[rconst.DUPLICATE_MASK],
+                                           params)
+      hr_sum = tf.reduce_sum(in_top_k * metric_weights)
+      hr_count = tf.reduce_sum(metric_weights)
+      return hr_sum, hr_count
+
+    per_replica_hr_sum, per_replica_hr_count = (
+        strategy.experimental_run_v2(
+            step_fn, args=(next(eval_iterator),)))
+    hr_sum = strategy.reduce(
+        tf.distribute.ReduceOp.SUM, per_replica_hr_sum, axis=None)
+    hr_count = strategy.reduce(
+        tf.distribute.ReduceOp.SUM, per_replica_hr_count, axis=None)
+    return hr_sum, hr_count
+
+  if not FLAGS.run_eagerly:
+    train_step = tf.function(train_step)
+    eval_step = tf.function(eval_step)
+
+  for callback in callbacks:
+    callback.on_train_begin()
+
+  train_loss = 0
+  for epoch in range(FLAGS.train_epochs):
+    for cb in callbacks:
+      cb.on_epoch_begin(epoch)
+    # As NCF dataset is sampled with randomness, not repeating
+    # data elements in each epoch has significant impact on
+    # convergence. As so, offline-generated TF record files
+    # contains all epoch worth of data. Thus we do not need
+    # to initialize dataset when reading from tf record files.
+    if generate_input_online:
+      train_input_iterator = iter(
+          strategy.experimental_distribute_dataset(train_input_dataset))
+    train_loss = 0
+    for step in range(num_train_steps):
+      current_step = step + epoch * num_train_steps
+      for c in callbacks:
+        c.on_batch_begin(current_step)
+      train_loss += train_step(train_input_iterator)
+      for c in callbacks:
+        c.on_batch_end(current_step)
+    train_loss /= num_train_steps
+    logging.info("Done training epoch %s, epoch loss=%s.", epoch + 1,
+                 train_loss)
+    eval_input_iterator = iter(
+        strategy.experimental_distribute_dataset(eval_input_dataset))
+    hr_sum = 0
+    hr_count = 0
+    for _ in range(num_eval_steps):
+      step_hr_sum, step_hr_count = eval_step(eval_input_iterator)
+      hr_sum += step_hr_sum
+      hr_count += step_hr_count
+    logging.info("Done eval epoch %s, hr=%s.", epoch + 1, hr_sum / hr_count)
+    if (FLAGS.early_stopping and
+        float(hr_sum / hr_count) > params["hr_threshold"]):
+      break
+
+  for c in callbacks:
+    c.on_train_end()
+  return train_loss, [None, hr_sum / hr_count]


 def build_stats(loss, eval_result, time_callback):
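
The custom loop above uses the TF 2.0-era distribution API: strategy.experimental_run_v2 executes a per-replica step function on the next distributed-iterator element, and strategy.reduce combines the per-replica results (in later TF releases experimental_run_v2 was renamed to Strategy.run). A stripped-down sketch of the same pattern:

    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy()
    dataset = tf.data.Dataset.from_tensors(tf.constant([2.0])).repeat(4).batch(2)
    dist_iterator = iter(strategy.experimental_distribute_dataset(dataset))

    def step_fn(x):
      return tf.reduce_sum(x)  # per-replica partial result

    per_replica = strategy.experimental_run_v2(
        step_fn, args=(next(dist_iterator),))
    total = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)
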
@@ -444,8 +485,6 @@ def main(_):
   with logger.benchmark_context(FLAGS), \
       mlperf_helper.LOGGER(FLAGS.output_ml_perf_compliance_logging):
     mlperf_helper.set_ncf_root(os.path.split(os.path.abspath(__file__))[0])
-    if FLAGS.tpu:
-      raise ValueError("NCF in Keras does not support TPU for now")
     run_ncf(FLAGS)
...
@@ -189,7 +189,7 @@ class NcfTest(tf.test.TestCase):
     self.assertAlmostEqual(ndcg, (1 + math.log(2) / math.log(3) +
                                   2 * math.log(2) / math.log(4)) / 4)

-  _BASE_END_TO_END_FLAGS = ['-batch_size', '1024', '-train_epochs', '1']
+  _BASE_END_TO_END_FLAGS = ['-batch_size', '1044', '-train_epochs', '1']

   @unittest.skipIf(keras_utils.is_v2_0(), "TODO(b/136018594)")
   @mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
...
@@ -109,7 +109,6 @@ def neumf_model_fn(features, labels, mode, params):
     mlperf_helper.ncf_print(key=mlperf_helper.TAGS.OPT_HP_ADAM_EPSILON,
                             value=params["epsilon"])
     optimizer = tf.compat.v1.train.AdamOptimizer(
         learning_rate=params["learning_rate"],
         beta1=params["beta1"],
@@ -151,7 +150,7 @@ def _strip_first_and_last_dimension(x, batch_size):
   return tf.reshape(x[0, :], (batch_size,))


-def construct_model(user_input, item_input, params, need_strip=False):
+def construct_model(user_input, item_input, params):
   # type: (tf.Tensor, tf.Tensor, dict) -> tf.keras.Model
   """Initialize NeuMF model.
@@ -184,34 +183,33 @@ def construct_model(user_input, item_input, params, need_strip=False):
   # Initializer for embedding layers
   embedding_initializer = "glorot_uniform"

-  if need_strip:
-    batch_size = params["batch_size"]
-
-    user_input_reshaped = tf.keras.layers.Lambda(
-        lambda x: _strip_first_and_last_dimension(
-            x, batch_size))(user_input)
-
-    item_input_reshaped = tf.keras.layers.Lambda(
-        lambda x: _strip_first_and_last_dimension(
-            x, batch_size))(item_input)
+  def mf_slice_fn(x):
+    x = tf.squeeze(x, [1])
+    return x[:, :mf_dim]
+
+  def mlp_slice_fn(x):
+    x = tf.squeeze(x, [1])
+    return x[:, mf_dim:]

   # It turns out to be significantly more effecient to store the MF and MLP
   # embedding portions in the same table, and then slice as needed.
-  mf_slice_fn = lambda x: x[:, :mf_dim]
-  mlp_slice_fn = lambda x: x[:, mf_dim:]
   embedding_user = tf.keras.layers.Embedding(
-      num_users, mf_dim + model_layers[0] // 2,
+      num_users,
+      mf_dim + model_layers[0] // 2,
       embeddings_initializer=embedding_initializer,
       embeddings_regularizer=tf.keras.regularizers.l2(mf_regularization),
-      input_length=1, name="embedding_user")(
-          user_input_reshaped if need_strip else user_input)
+      input_length=1,
+      name="embedding_user")(
+          user_input)

   embedding_item = tf.keras.layers.Embedding(
-      num_items, mf_dim + model_layers[0] // 2,
+      num_items,
+      mf_dim + model_layers[0] // 2,
       embeddings_initializer=embedding_initializer,
       embeddings_regularizer=tf.keras.regularizers.l2(mf_regularization),
-      input_length=1, name="embedding_item")(
-          item_input_reshaped if need_strip else item_input)
+      input_length=1,
+      name="embedding_item")(
+          item_input)

   # GMF part
   mf_user_latent = tf.keras.layers.Lambda(
...
@@ -24,6 +24,8 @@ import random
 import string

 import tensorflow as tf

+from official.utils.misc import tpu_lib
+

 def _collective_communication(all_reduce_alg):
   """Return a CollectiveCommunication based on all_reduce_alg.
@@ -83,16 +85,18 @@ def get_distribution_strategy(distribution_strategy="default",
                               num_gpus=0,
                               num_workers=1,
                               all_reduce_alg=None,
-                              num_packs=1):
+                              num_packs=1,
+                              tpu_address=None):
   """Return a DistributionStrategy for running the model.

   Args:
     distribution_strategy: a string specifying which distribution strategy to
       use. Accepted values are 'off', 'default', 'one_device', 'mirrored',
-      'parameter_server', 'multi_worker_mirrored', case insensitive. 'off' means
-      not to use Distribution Strategy; 'default' means to choose from
+      'parameter_server', 'multi_worker_mirrored', and 'tpu' -- case insensitive.
+      'off' means not to use Distribution Strategy; 'default' means to choose from
       `MirroredStrategy`, `MultiWorkerMirroredStrategy`, or `OneDeviceStrategy`
-      according to the number of GPUs and number of workers.
+      according to the number of GPUs and number of workers. 'tpu' means to use
+      TPUStrategy using `tpu_address`.
     num_gpus: Number of GPUs to run this model.
     num_workers: Number of workers to run this model.
     all_reduce_alg: Optional. Specifies which algorithm to use when performing
@@ -102,12 +106,14 @@ def get_distribution_strategy(distribution_strategy="default",
       device topology.
     num_packs: Optional. Sets the `num_packs` in `tf.distribute.NcclAllReduce`
       or `tf.distribute.HierarchicalCopyAllReduce` for `MirroredStrategy`.
+    tpu_address: Optional. String that represents TPU to connect to. Must not
+      be None if `distribution_strategy` is set to `tpu`.

   Returns:
     tf.distribute.DistibutionStrategy object.

   Raises:
     ValueError: if `distribution_strategy` is 'off' or 'one_device' and
-      `num_gpus` is larger than 1; or `num_gpus` is negative.
+      `num_gpus` is larger than 1; or `num_gpus` is negative or if
+      `distribution_strategy` is `tpu` but `tpu_address` is not specified.
   """
   if num_gpus < 0:
     raise ValueError("`num_gpus` can not be negative.")
@@ -120,6 +126,15 @@ def get_distribution_strategy(distribution_strategy="default",
           "flag cannot be set to 'off'.".format(num_gpus, num_workers))
     return None

+  if distribution_strategy == "tpu":
+    if not tpu_address:
+      raise ValueError("`tpu_address` must be specified when using "
+                       "TPUStrategy.")
+    # Initialize TPU System.
+    cluster_resolver = tpu_lib.tpu_initialize(tpu_address)
+    return tf.distribute.experimental.TPUStrategy(cluster_resolver)
+
   if distribution_strategy == "multi_worker_mirrored":
     return tf.distribute.experimental.MultiWorkerMirroredStrategy(
         communication=_collective_communication(all_reduce_alg))
...
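
Given the new 'tpu' branch above, selecting a TPUStrategy would look roughly like this (the TPU address is a placeholder):

    from official.utils.misc import distribution_utils

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy="tpu",
        tpu_address="grpc://10.0.0.2:8470")  # placeholder address
    with strategy.scope():
      pass  # build the Keras model under the strategy scope
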
@@ -31,3 +31,8 @@ def tpu_initialize(tpu_address):
   tf.config.experimental_connect_to_host(cluster_resolver.master())
   tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
   return cluster_resolver
+
+
+def get_primary_cpu_task(use_remote_tpu=False):
+  """Returns remote TPU worker address. No-op for GPU/CPU training."""
+  return "/job:worker" if use_remote_tpu else ""