Unverified Commit f2b702a0 authored by Reed, committed by GitHub

Add XLA support to NCF (#5572)

parent bf298439
@@ -511,8 +511,8 @@ def make_deserialize(params, batch_size, training=False):
     items = tf.reshape(tf.decode_raw(
         features[movielens.ITEM_COLUMN], tf.uint16), (batch_size,))
 
-    if params["use_tpu"]:
-      items = tf.cast(items, tf.int32)  # TPU doesn't allow uint16 infeed.
+    if params["use_tpu"] or params["use_xla_for_gpu"]:
+      items = tf.cast(items, tf.int32)  # TPU and XLA disallow uint16 infeed.
 
     if not training:
       dupe_mask = tf.reshape(tf.cast(tf.decode_raw(
......
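Per the comment in the hunk above, TPU and XLA infeed do not accept uint16 tensors, so the item ids decoded as uint16 have to be widened before they reach the compiled model. A minimal, self-contained sketch of that cast pattern under TF 1.x, with a hypothetical `params` dict standing in for the real pipeline (this is not the repo's deserialize function):

```python
# Sketch only: mirrors the cast above outside the NCF input pipeline.
import numpy as np
import tensorflow as tf

params = {"use_tpu": False, "use_xla_for_gpu": True}  # hypothetical settings
batch_size = 4

raw_bytes = np.arange(batch_size, dtype=np.uint16).tobytes()  # fake record
items = tf.reshape(tf.decode_raw(raw_bytes, tf.uint16), (batch_size,))
if params["use_tpu"] or params["use_xla_for_gpu"]:
  # uint16 is not accepted for TPU/XLA infeed, so widen to int32.
  items = tf.cast(items, tf.int32)

with tf.Session() as sess:
  print(sess.run(items))  # -> [0 1 2 3], now int32
```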
@@ -124,7 +124,8 @@ class BaseTest(tf.test.TestCase):
     with g.as_default():
       input_fn, record_dir, batch_count = \
         data_preprocessing.make_input_fn(ncf_dataset, True)
-      dataset = input_fn({"batch_size": BATCH_SIZE, "use_tpu": False})
+      dataset = input_fn({"batch_size": BATCH_SIZE, "use_tpu": False,
+                          "use_xla_for_gpu": False})
       first_epoch = self.drain_dataset(dataset=dataset, g=g)
 
     user_inv_map = {v: k for k, v in ncf_dataset.user_map.items()}
     item_inv_map = {v: k for k, v in ncf_dataset.item_map.items()}
......
@@ -37,6 +37,7 @@ from absl import flags
 import tensorflow as tf
 # pylint: enable=g-bad-import-order
 
+from tensorflow.contrib.compiler import xla
 from official.datasets import movielens
 from official.recommendation import constants as rconst
 from official.recommendation import data_preprocessing
@@ -105,9 +106,12 @@ def construct_estimator(num_gpus, model_dir, params, batch_size,
   distribution = distribution_utils.get_distribution_strategy(num_gpus=num_gpus)
   run_config = tf.estimator.RunConfig(train_distribute=distribution)
   params["eval_batch_size"] = eval_batch_size
 
-  estimator = tf.estimator.Estimator(model_fn=neumf_model.neumf_model_fn,
-                                     model_dir=model_dir, config=run_config,
-                                     params=params)
+  model_fn = neumf_model.neumf_model_fn
+  if params["use_xla_for_gpu"]:
+    tf.logging.info("Using XLA for GPU for training and evaluation.")
+    model_fn = xla.estimator_model_fn(model_fn)
+
+  estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir,
+                                     config=run_config, params=params)
   return estimator, estimator
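The hunk above is the core of the change: when `use_xla_for_gpu` is set, the model function is wrapped with `tensorflow.contrib.compiler.xla.estimator_model_fn` before being handed to the Estimator. A rough, self-contained sketch of that wrapping pattern under TF 1.x contrib APIs, using a toy model_fn rather than `neumf_model.neumf_model_fn`:

```python
# Sketch of the wrapping pattern only, not the NCF model.
import tensorflow as tf
from tensorflow.contrib.compiler import xla

def toy_model_fn(features, labels, mode, params):
  """Stand-in model_fn: a single dense layer trained with SGD."""
  logits = tf.layers.dense(features["x"], 1)
  loss = tf.losses.mean_squared_error(labels, logits)
  train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
      loss, global_step=tf.train.get_or_create_global_step())
  return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

use_xla_for_gpu = True  # stand-in for params["use_xla_for_gpu"]
model_fn = toy_model_fn
if use_xla_for_gpu:
  # Wraps the model_fn so its train/eval graphs are built under XLA compilation.
  model_fn = xla.estimator_model_fn(model_fn)

estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir="/tmp/xla_ncf_demo")
```

Because the wrapped model_fn goes through XLA's estimator plumbing, the evaluation path has to hand back a TPUEstimatorSpec, which is why the neumf_model.py hunks below thread `use_tpu_spec` through.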
@@ -187,6 +191,7 @@ def run_ncf(_):
         "beta2": FLAGS.beta2,
         "epsilon": FLAGS.epsilon,
         "match_mlperf": FLAGS.ml_perf,
+        "use_xla_for_gpu": FLAGS.use_xla_for_gpu,
     }, batch_size=flags.FLAGS.batch_size, eval_batch_size=eval_batch_size)
 
   # Create hooks that log information about the training and metric values
@@ -409,6 +414,13 @@ def define_ncf_flags():
           "not need to be set."
       ))
 
+  flags.DEFINE_bool(
+      name="use_xla_for_gpu", default=False, help=flags_core.help_wrap(
+          "If True, use XLA for the model function. Only works when using a "
+          "GPU. On TPUs, XLA is always used."))
+
+  flags.mark_flags_as_mutual_exclusive(["use_xla_for_gpu", "tpu"])
+
 
 if __name__ == "__main__":
   tf.logging.set_verbosity(tf.logging.INFO)
......
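The new flag above is declared mutually exclusive with `--tpu` via `absl.flags.mark_flags_as_mutual_exclusive`. A small stand-alone sketch of that absl pattern (not the NCF flag module; `flags_core.help_wrap` is omitted and the flag names are just mirrored for illustration):

```python
# Sketch of the absl.flags pattern used above, as a standalone script.
from absl import app
from absl import flags

flags.DEFINE_bool(
    name="use_xla_for_gpu", default=False,
    help="If True, use XLA for the model function (GPU only).")
flags.DEFINE_string(name="tpu", default=None, help="Name of the TPU to use.")

# Passing both flags on the command line fails flag validation at startup.
flags.mark_flags_as_mutual_exclusive(["use_xla_for_gpu", "tpu"])

def main(_):
  print("use_xla_for_gpu:", flags.FLAGS.use_xla_for_gpu)

if __name__ == "__main__":
  app.run(main)  # e.g. `python demo.py --use_xla_for_gpu` parses fine.
```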
@@ -97,7 +97,8 @@ def neumf_model_fn(features, labels, mode, params):
     duplicate_mask = tf.cast(features[rconst.DUPLICATE_MASK], tf.float32)
     return compute_eval_loss_and_metrics(
         logits, softmax_logits, duplicate_mask, params["num_neg"],
-        params["match_mlperf"], params["use_tpu"])
+        params["match_mlperf"],
+        use_tpu_spec=params["use_tpu"] or params["use_xla_for_gpu"])
 
   elif mode == tf.estimator.ModeKeys.TRAIN:
     labels = tf.cast(labels, tf.int32)
@@ -234,7 +235,7 @@ def compute_eval_loss_and_metrics(logits,              # type: tf.Tensor
                                   duplicate_mask,      # type: tf.Tensor
                                   num_training_neg,    # type: int
                                   match_mlperf=False,  # type: bool
-                                  use_tpu=False        # type: bool
+                                  use_tpu_spec=False   # type: bool
                                  ):
   # type: (...) -> tf.estimator.EstimatorSpec
   """Model evaluation with HR and NDCG metrics.
@@ -293,7 +294,9 @@ def compute_eval_loss_and_metrics(logits,              # type: tf.Tensor
     match_mlperf: Use the MLPerf reference convention for computing rank.
-    use_tpu: Should the evaluation be performed on a TPU.
+    use_tpu_spec: Should a TPUEstimatorSpec be returned instead of an
+      EstimatorSpec. Required for TPUs and if XLA is done on a GPU. Despite
+      its name, TPUEstimatorSpecs work with GPUs.
 
   Returns:
     An EstimatorSpec for evaluation.
@@ -334,7 +337,7 @@ def compute_eval_loss_and_metrics(logits,              # type: tf.Tensor
         rconst.NDCG_KEY: tf.metrics.mean(ndcg_tensor, weights=weight_tensor),
     }
 
-  if use_tpu:
+  if use_tpu_spec:
     return tf.contrib.tpu.TPUEstimatorSpec(
         mode=tf.estimator.ModeKeys.EVAL, loss=cross_entropy,
         eval_metrics=(metric_fn, [in_top_k, ndcg, metric_weights]))
......
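The final hunk renames `use_tpu` to `use_tpu_spec` and keys the return type on it: a `tf.contrib.tpu.TPUEstimatorSpec` (with metrics passed as a `(metric_fn, tensors)` pair) when the graph is TPU- or XLA-compiled, otherwise a plain `tf.estimator.EstimatorSpec`. A hedged sketch contrasting the two forms with made-up metric names and tensors, not the NCF metrics themselves:

```python
# Illustrative only: contrasts the two eval-spec forms selected above.
import tensorflow as tf

def metric_fn(top_k_tensor, weight_tensor):
  # TPUEstimatorSpec defers metric construction to a function that is
  # called later with the listed tensors.
  return {"hit_rate": tf.metrics.mean(top_k_tensor, weights=weight_tensor)}

def build_eval_spec(cross_entropy, top_k_tensor, weight_tensor, use_tpu_spec):
  if use_tpu_spec:
    # Form needed when the model_fn is wrapped for TPUs or XLA-on-GPU:
    # metrics travel as (fn, [tensors]) rather than as finished metric ops.
    return tf.contrib.tpu.TPUEstimatorSpec(
        mode=tf.estimator.ModeKeys.EVAL, loss=cross_entropy,
        eval_metrics=(metric_fn, [top_k_tensor, weight_tensor]))
  # Plain EstimatorSpec takes already-constructed metric ops.
  return tf.estimator.EstimatorSpec(
      mode=tf.estimator.ModeKeys.EVAL, loss=cross_entropy,
      eval_metric_ops=metric_fn(top_k_tensor, weight_tensor))
```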