Unverified commit 1255d5b9, authored by Shining Sun and committed by GitHub

Add DS support for NCF keras (#6447)

* add ds support for ncf

* remove comments for in_top_k

* avoid expanding the input layers

* resolve comments and fix lint

* Add some comments in code and fix lint

* fix lint

* add some documentation

* add tensorflow imports
parent a721f413
@@ -31,11 +31,12 @@ import tensorflow as tf
 # pylint: enable=g-bad-import-order

 from official.datasets import movielens
+from official.recommendation import constants as rconst
 from official.recommendation import ncf_common
 from official.recommendation import neumf_model
-from official.recommendation import constants as rconst
 from official.utils.logs import logger
 from official.utils.logs import mlperf_helper
+from official.utils.misc import distribution_utils
 from official.utils.misc import keras_utils
 from official.utils.misc import model_helpers
@@ -57,7 +58,7 @@ def _get_metric_fn(params):
   def metric_fn(y_true, y_pred):
     """Returns the in_top_k metric."""
-    softmax_logits = y_pred
+    softmax_logits = y_pred[0, :]
     logits = tf.slice(softmax_logits, [0, 1], [batch_size, 1])
     # The dup mask should be obtained from input data, but we did not yet find
@@ -74,8 +75,12 @@ def _get_metric_fn(params):
             params["match_mlperf"],
             params["use_xla_for_gpu"]))
+    is_training = tf.keras.backend.learning_phase()
+    if isinstance(is_training, int):
+      is_training = tf.constant(bool(is_training), dtype=tf.bool)
     in_top_k = tf.cond(
-        tf.keras.backend.learning_phase(),
+        is_training,
         lambda: tf.zeros(shape=in_top_k.shape, dtype=in_top_k.dtype),
         lambda: in_top_k)
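The guard above exists because `tf.keras.backend.learning_phase()` may return either a plain Python int (when the phase has been set explicitly) or a scalar tensor, while `tf.cond` requires a tensor predicate. A minimal standalone sketch of the same pattern (the helper name `zero_when_training` is illustrative, not part of this change):

```python
import tensorflow as tf

def zero_when_training(metric_value):
  """Returns zeros during training so the metric only counts in eval."""
  # learning_phase() may be a plain int or a scalar tensor; normalize it
  # to a bool tensor before handing it to tf.cond.
  is_training = tf.keras.backend.learning_phase()
  if isinstance(is_training, int):
    is_training = tf.constant(bool(is_training), dtype=tf.bool)
  return tf.cond(is_training,
                 lambda: tf.zeros_like(metric_value),
                 lambda: metric_value)
```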
@@ -87,11 +92,35 @@ def _get_metric_fn(params):

 def _get_train_and_eval_data(producer, params):
   """Returns the datasets for training and evaluating."""
+
+  def preprocess_train_input(features, labels):
+    """Pre-process the training data.
+
+    This is needed because:
+    - Distributed training does not support extra inputs. The current
+      implementation does not use the VALID_POINT_MASK in the input, which
+      makes it an extra input, so it needs to be removed.
+    - The label needs to be extended to be used in the loss fn.
+    """
+    features.pop(rconst.VALID_POINT_MASK)
+    labels = tf.expand_dims(labels, -1)
+    return features, labels
+
   train_input_fn = producer.make_input_fn(is_training=True)
-  train_input_dataset = train_input_fn(params)
+  train_input_dataset = train_input_fn(params).map(
+      preprocess_train_input)

   def preprocess_eval_input(features):
+    """Pre-process the eval data.
+
+    This is needed because:
+    - Distributed training does not support extra inputs. The current
+      implementation does not use the DUPLICATE_MASK in the input, which
+      makes it an extra input, so it needs to be removed.
+    - The label needs to be extended to be used in the loss fn.
+    """
+    features.pop(rconst.DUPLICATE_MASK)
     labels = tf.zeros_like(features[movielens.USER_COLUMN])
+    labels = tf.expand_dims(labels, -1)
     return features, labels

   eval_input_fn = producer.make_input_fn(is_training=False)
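Both preprocess functions follow the same recipe: drop any feature the model has no input layer for (distribution strategies reject unused extras) and expand the labels to rank 2 for the loss. A self-contained sketch with made-up feature keys (`user_id` and `valid_point_mask` are stand-ins, not the producer's real constants):

```python
import tensorflow as tf

USER_KEY, MASK_KEY = "user_id", "valid_point_mask"  # illustrative keys

def drop_extra_and_expand(features, labels):
  # Remove the unused mask so every remaining feature feeds a model input.
  features.pop(MASK_KEY)
  # Expand labels from (batch,) to (batch, 1) to match the logits shape.
  return features, tf.expand_dims(labels, -1)

dataset = tf.data.Dataset.from_tensor_slices((
    {USER_KEY: [1, 2, 3], MASK_KEY: [1, 1, 0]},
    [0, 1, 1],
)).batch(3).map(drop_extra_and_expand)
```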
@@ -120,26 +149,36 @@ def _get_keras_model(params):
   """Constructs and returns the model."""
   batch_size = params['batch_size']

+  # The input layers are of shape (1, batch_size) to match the shape of the
+  # input data. The leading dimension is needed because the input data must be
+  # batched to use distribution strategies, and here each replica's outer
+  # batch is designed to have size 1.
   user_input = tf.keras.layers.Input(
-      shape=(),
-      batch_size=batch_size,
+      shape=(batch_size,),
+      batch_size=1,
       name=movielens.USER_COLUMN,
-      dtype=rconst.USER_DTYPE)
+      dtype=tf.int32)

   item_input = tf.keras.layers.Input(
-      shape=(),
-      batch_size=batch_size,
+      shape=(batch_size,),
+      batch_size=1,
       name=movielens.ITEM_COLUMN,
-      dtype=rconst.ITEM_DTYPE)
+      dtype=tf.int32)

-  base_model = neumf_model.construct_model(user_input, item_input, params)
+  base_model = neumf_model.construct_model(
+      user_input, item_input, params, need_strip=True)
+
   base_model_output = base_model.output

+  logits = tf.keras.layers.Lambda(
+      lambda x: tf.expand_dims(x, 0),
+      name="logits")(base_model_output)
+
   zeros = tf.keras.layers.Lambda(
-      lambda x: x * 0)(base_model_output)
+      lambda x: x * 0)(logits)

   softmax_logits = tf.keras.layers.concatenate(
-      [zeros, base_model_output],
+      [zeros, logits],
       axis=-1)

   keras_model = tf.keras.Model(
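The `(1, batch_size)` trick means each replica sees an outer batch of exactly one element carrying a full inner batch of ids, and the `expand_dims` Lambda restores that leading dimension on the logits. A small sketch of the resulting shape behavior (all values are assumptions for illustration):

```python
import tensorflow as tf

BATCH_SIZE = 256  # assumed inner batch size

# One outer batch element per replica, carrying BATCH_SIZE ids.
user_input = tf.keras.layers.Input(
    shape=(BATCH_SIZE,), batch_size=1, name="user_id", dtype=tf.int32)
embedding = tf.keras.layers.Embedding(10000, 16)(user_input)
model = tf.keras.Model(inputs=user_input, outputs=embedding)
print(model.output_shape)  # (1, 256, 16): leading replica-batch dim kept
```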
@@ -175,33 +214,39 @@ def run_ncf(_):
   producer.start()
   model_helpers.apply_clean(flags.FLAGS)

-  keras_model = _get_keras_model(params)
-  optimizer = ncf_common.get_optimizer(params)
-
-  time_callback = keras_utils.TimeHistory(batch_size, FLAGS.log_steps)
-
-  keras_model.compile(
-      loss=_keras_loss,
-      metrics=[_get_metric_fn(params)],
-      optimizer=optimizer)
-
-  train_input_dataset, eval_input_dataset = _get_train_and_eval_data(
-      producer, params)
-
-  history = keras_model.fit(
-      train_input_dataset,
-      epochs=FLAGS.train_epochs,
-      callbacks=[
-          IncrementEpochCallback(producer),
-          time_callback],
-      verbose=2)
-
-  tf.logging.info("Training done. Start evaluating")
-
-  eval_results = keras_model.evaluate(
-      eval_input_dataset,
-      steps=num_eval_steps,
-      verbose=2)
+  batches_per_step = params["batches_per_step"]
+  train_input_dataset, eval_input_dataset = _get_train_and_eval_data(producer,
+                                                                     params)
+  # For distributed training, the dataset must be batched with batch(). Its
+  # argument here is the number of replicas involved, so that each replica
+  # evenly gets one slice of the data.
+  train_input_dataset = train_input_dataset.batch(batches_per_step)
+  eval_input_dataset = eval_input_dataset.batch(batches_per_step)
+
+  strategy = ncf_common.get_distribution_strategy(params)
+  with distribution_utils.get_strategy_scope(strategy):
+    keras_model = _get_keras_model(params)
+    optimizer = ncf_common.get_optimizer(params)
+    time_callback = keras_utils.TimeHistory(batch_size, FLAGS.log_steps)
+
+    keras_model.compile(
+        loss=_keras_loss,
+        metrics=[_get_metric_fn(params)],
+        optimizer=optimizer)
+
+    history = keras_model.fit(train_input_dataset,
+                              epochs=FLAGS.train_epochs,
+                              callbacks=[
+                                  IncrementEpochCallback(producer),
+                                  time_callback],
+                              verbose=2)
+
+    tf.logging.info("Training done. Start evaluating")
+
+    eval_results = keras_model.evaluate(
+        eval_input_dataset,
+        steps=num_eval_steps,
+        verbose=2)

   tf.logging.info("Keras evaluation is done.")
@@ -249,9 +294,6 @@ def main(_):
   mlperf_helper.set_ncf_root(os.path.split(os.path.abspath(__file__))[0])
   if FLAGS.tpu:
     raise ValueError("NCF in Keras does not support TPU for now")
-  if FLAGS.num_gpus > 1:
-    raise ValueError("NCF in Keras does not support distribution strategies. "
-                     "Please set num_gpus to 1")
   run_ncf(FLAGS)
...
@@ -21,6 +21,7 @@ from __future__ import print_function
 import os
 import time

 from absl import flags
+import tensorflow as tf  # pylint: disable=g-bad-import-order

 from official.resnet import cifar10_main as cifar_main
 from official.resnet.keras import keras_benchmark
...
@@ -125,7 +125,7 @@ def run(flags_obj):
       distribution_strategy=flags_obj.distribution_strategy,
       num_gpus=flags_obj.num_gpus)

-  strategy_scope = keras_common.get_strategy_scope(strategy)
+  strategy_scope = distribution_utils.get_strategy_scope(strategy)

   if flags_obj.use_synthetic_data:
     distribution_utils.set_up_synthetic_data()
...
@@ -339,24 +339,6 @@ def is_v2_0():
   return tf.__version__.startswith('2')

-def get_strategy_scope(strategy):
-  if strategy:
-    strategy_scope = strategy.scope()
-  else:
-    strategy_scope = DummyContextManager()
-
-  return strategy_scope
-
-
-class DummyContextManager(object):
-
-  def __enter__(self):
-    pass
-
-  def __exit__(self, *args):
-    pass
-
-
 def _monkey_patch_org_assert_broadcastable():
   """Monkey-patch `assert_broadcast` op to avoid OOM when enabling XLA."""
   def no_op_assert_broadcastable(weights, values):
...
@@ -19,6 +19,7 @@ import os
 import time

 from absl import flags
+import tensorflow as tf  # pylint: disable=g-bad-import-order

 from official.resnet import imagenet_main
 from official.resnet.keras import keras_benchmark
...
@@ -124,7 +124,7 @@ def run(flags_obj):
       num_gpus=flags_obj.num_gpus,
       num_workers=distribution_utils.configure_cluster())

-  strategy_scope = keras_common.get_strategy_scope(strategy)
+  strategy_scope = distribution_utils.get_strategy_scope(strategy)

   # pylint: disable=protected-access
   if flags_obj.use_synthetic_data:
...
@@ -248,3 +248,21 @@ def configure_cluster(worker_hosts=None, task_index=-1):
   else:
     num_workers = 1
   return num_workers
+
+
+def get_strategy_scope(strategy):
+  if strategy:
+    strategy_scope = strategy.scope()
+  else:
+    strategy_scope = DummyContextManager()
+
+  return strategy_scope
+
+
+class DummyContextManager(object):
+
+  def __enter__(self):
+    pass
+
+  def __exit__(self, *args):
+    pass
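Moving `get_strategy_scope` into `distribution_utils` gives every caller one idiom that works with or without a strategy: a real `strategy.scope()` when one is configured, a no-op `DummyContextManager` otherwise. A minimal usage sketch (the toy model is illustrative only):

```python
import tensorflow as tf
from official.utils.misc import distribution_utils

strategy = None  # or e.g. tf.distribute.MirroredStrategy()

# The same `with` works whether strategy is a real strategy or None.
with distribution_utils.get_strategy_scope(strategy):
  model = tf.keras.Sequential(
      [tf.keras.layers.Dense(1, input_shape=(4,))])
  model.compile(optimizer="sgd", loss="mse")
```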