Commit 8bd9aa11 authored by Hongkun Yu, committed by A. Unique TensorFlower

Make NCF not depend on tf.contrib.

Remove the unmaintained code path.

PiperOrigin-RevId: 285869559
parent 8d9a16ce
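In practice, as the hunks below show, the change amounts to importing StreamingFilesDataset from TensorFlow's core TPU package instead of tf.contrib, and deleting the --use_xla_for_gpu flag along with the contrib XLA wrapper around the model function. Below is a minimal sketch of the new-style import and dataset construction; the directory, worker job name, and shard count are hypothetical stand-ins for what the NCF pipeline actually derives from rconst and popen_helper, and iterating the dataset still requires the remote worker process that the pipeline launches.

import os

# Core-TF import path used after this commit (previously under tf.contrib.tpu).
from tensorflow.python.tpu.datasets import StreamingFilesDataset

epoch_data_dir = "/tmp/ncf_epoch_data"            # hypothetical shard directory
file_pattern = os.path.join(epoch_data_dir, "*")  # NCF builds this from rconst.SHARD_TEMPLATE.format("*")

# Streams serialized shards from a remote worker; in NCF the worker job name
# comes from popen_helper.worker_job() and the shard count from rconst.NUM_FILE_SHARDS.
dataset = StreamingFilesDataset(
    files=file_pattern,
    worker_job="worker",      # hypothetical job name
    num_parallel_reads=16,    # stand-in for rconst.NUM_FILE_SHARDS
    num_epochs=1)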
@@ -39,6 +39,7 @@ from official.recommendation import constants as rconst
 from official.recommendation import movielens
 from official.recommendation import popen_helper
 from official.recommendation import stat_utils
+from tensorflow.python.tpu.datasets import StreamingFilesDataset
 
 
 SUMMARY_TEMPLATE = """General:
...
@@ -286,10 +287,6 @@ class DatasetManager(object):
       file_pattern = os.path.join(
           epoch_data_dir, rconst.SHARD_TEMPLATE.format("*"))
-      # TODO(seemuch): remove this contrib import
-      # pylint: disable=line-too-long
-      from tensorflow.contrib.tpu.python.tpu.datasets import StreamingFilesDataset
-      # pylint: enable=line-too-long
       dataset = StreamingFilesDataset(
           files=file_pattern, worker_job=popen_helper.worker_job(),
           num_parallel_reads=rconst.NUM_FILE_SHARDS, num_epochs=1,
...
@@ -94,7 +94,6 @@ def parse_flags(flags_obj):
       "beta2": flags_obj.beta2,
       "epsilon": flags_obj.epsilon,
       "match_mlperf": flags_obj.ml_perf,
-      "use_xla_for_gpu": flags_obj.use_xla_for_gpu,
       "epochs_between_evals": FLAGS.epochs_between_evals,
       "keras_use_ctl": flags_obj.keras_use_ctl,
       "hr_threshold": flags_obj.hr_threshold,
@@ -307,16 +306,6 @@ def define_ncf_flags():
     return (eval_batch_size is None or
             int(eval_batch_size) > rconst.NUM_EVAL_NEGATIVES)
 
-  flags.DEFINE_bool(
-      name="use_xla_for_gpu", default=False, help=flags_core.help_wrap(
-          "If True, use XLA for the model function. Only works when using a "
-          "GPU. On TPUs, XLA is always used"))
-
-  xla_message = "--use_xla_for_gpu is incompatible with --tpu"
-  @flags.multi_flags_validator(["use_xla_for_gpu", "tpu"], message=xla_message)
-  def xla_validator(flag_dict):
-    return not flag_dict["use_xla_for_gpu"] or not flag_dict["tpu"]
-
   flags.DEFINE_bool(
       name="early_stopping",
       default=False,
...
@@ -57,25 +57,19 @@ FLAGS = flags.FLAGS
 def construct_estimator(model_dir, params):
-  """Construct either an Estimator or TPUEstimator for NCF.
+  """Construct either an Estimator for NCF.
 
   Args:
     model_dir: The model directory for the estimator
     params: The params dict for the estimator
 
   Returns:
-    An Estimator or TPUEstimator.
+    An Estimator.
   """
   distribution = ncf_common.get_v1_distribution_strategy(params)
   run_config = tf.estimator.RunConfig(train_distribute=distribution,
                                       eval_distribute=distribution)
   model_fn = neumf_model.neumf_model_fn
-  if params["use_xla_for_gpu"]:
-    # TODO(seemuch): remove the contrib imput
-    from tensorflow.contrib.compiler import xla
-    logging.info("Using XLA for GPU for training and evaluation.")
-    model_fn = xla.estimator_model_fn(model_fn)
   estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir,
                                      config=run_config, params=params)
   return estimator
...
@@ -93,7 +93,7 @@ def neumf_model_fn(features, labels, mode, params):
         duplicate_mask,
         params["num_neg"],
         params["match_mlperf"],
-        use_tpu_spec=params["use_xla_for_gpu"])
+        use_tpu_spec=params["use_tpu"])
 
   elif mode == tf.estimator.ModeKeys.TRAIN:
     labels = tf.cast(labels, tf.int32)
@@ -269,8 +269,7 @@ def _get_estimator_spec_with_metrics(logits,  # type: tf.Tensor
       softmax_logits,
       duplicate_mask,
       num_training_neg,
-      match_mlperf,
-      use_tpu_spec)
+      match_mlperf)
 
   if use_tpu_spec:
     return tf.estimator.tpu.TPUEstimatorSpec(
@@ -285,13 +284,13 @@ def _get_estimator_spec_with_metrics(logits,  # type: tf.Tensor
     )
 
 
-def compute_eval_loss_and_metrics_helper(logits,              # type: tf.Tensor
-                                         softmax_logits,      # type: tf.Tensor
-                                         duplicate_mask,      # type: tf.Tensor
-                                         num_training_neg,    # type: int
-                                         match_mlperf=False,  # type: bool
-                                         use_tpu_spec=False   # type: bool
-                                        ):
+def compute_eval_loss_and_metrics_helper(
+    logits,              # type: tf.Tensor
+    softmax_logits,      # type: tf.Tensor
+    duplicate_mask,      # type: tf.Tensor
+    num_training_neg,    # type: int
+    match_mlperf=False   # type: bool
+):
   """Model evaluation with HR and NDCG metrics.
 
   The evaluation protocol is to rank the test interacted item (truth items)
@@ -348,10 +347,6 @@ def compute_eval_loss_and_metrics_helper(logits,  # type: tf.Tensor
     match_mlperf: Use the MLPerf reference convention for computing rank.
-    use_tpu_spec: Should a TPUEstimatorSpec be returned instead of an
-      EstimatorSpec. Required for TPUs and if XLA is done on a GPU. Despite its
-      name, TPUEstimatorSpecs work with GPUs
 
   Returns:
     cross_entropy: the loss
     metric_fn: the metrics function
...
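Taken together, the last three hunks move the TPU-vs-GPU spec decision out of compute_eval_loss_and_metrics_helper: the helper drops its use_tpu_spec argument, the caller (apparently _get_estimator_spec_with_metrics, whose parameter list matches) keeps branching on use_tpu_spec when building the spec, and the model function now feeds that flag from params["use_tpu"] rather than the removed use_xla_for_gpu option. A rough sketch of the resulting call shape, inferred from the diff rather than copied from the repository:

# In the EVAL branch of neumf_model_fn (per the @@ -93 hunk): the spec choice is
# keyed off the existing use_tpu param, not the deleted use_xla_for_gpu flag.
spec = _get_estimator_spec_with_metrics(
    logits,
    softmax_logits,
    duplicate_mask,
    params["num_neg"],
    params["match_mlperf"],
    use_tpu_spec=params["use_tpu"])

# Inside _get_estimator_spec_with_metrics (per the @@ -269 hunk): the helper is
# now called without use_tpu_spec; only the choice between TPUEstimatorSpec and
# EstimatorSpec still depends on it.
loss_and_metrics = compute_eval_loss_and_metrics_helper(
    logits,
    softmax_logits,
    duplicate_mask,
    num_training_neg,
    match_mlperf)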