Commit 826eea75 authored by Reed, committed by Taylor Robie

Add --use_while_loop option. (#5653)

parent c4c49d1a
@@ -58,6 +58,10 @@ HR_KEY = "HR"
 NDCG_KEY = "NDCG"
 DUPLICATE_MASK = "duplicate_mask"
 
+# Metric names
+HR_METRIC_NAME = "HR_METRIC"
+NDCG_METRIC_NAME = "NDCG_METRIC"
+
 # ==============================================================================
 # == Subprocess Data Generation ================================================
 # ==============================================================================
...
@@ -25,6 +25,7 @@ import time
 import tensorflow as tf
 from tensorflow.contrib.compiler import xla
 
+from official.recommendation import constants as rconst
 from official.recommendation import data_preprocessing
 from official.recommendation import neumf_model
@@ -58,27 +59,67 @@ class NcfModelRunner(object):
       _SHARED_MODEL_PROPERTY_FIELDS)
   _EvalModelProperties = namedtuple(  # pylint: disable=invalid-name
       "_EvalModelProperties", _SHARED_MODEL_PROPERTY_FIELDS + (
-          # A dict from metric name to (metric, update_op) tuple.
+          # A dict from metric name to metric tensor.
          "metrics",
          # Initializes the metric variables.
          "metric_initializer",))
 
-  def __init__(self, ncf_dataset, params):
+  def __init__(self, ncf_dataset, params, num_train_steps, num_eval_steps,
+               use_while_loop):
+    self._num_train_steps = num_train_steps
+    self._num_eval_steps = num_eval_steps
+    self._use_while_loop = use_while_loop
     with tf.Graph().as_default() as self._graph:
       if params["use_xla_for_gpu"]:
         # The XLA functions we use require resource variables.
         tf.enable_resource_variables()
       self._ncf_dataset = ncf_dataset
       self._global_step = tf.train.create_global_step()
-      self._train_model_properties = self._build_model(params, is_training=True)
-      self._eval_model_properties = self._build_model(params, is_training=False)
+      self._train_model_properties = self._build_model(params, num_train_steps,
+                                                       is_training=True)
+      self._eval_model_properties = self._build_model(params, num_eval_steps,
+                                                      is_training=False)
 
       initializer = tf.global_variables_initializer()
     self._graph.finalize()
     self._session = tf.Session(graph=self._graph)
     self._session.run(initializer)
-  def _build_model(self, params, is_training):
+  def _compute_metric_mean(self, metric_name):
+    """Computes the mean from a call to tf.metrics.mean().
+
+    tf.metrics.mean() already returns the mean, so normally this call is
+    unnecessary. But, if tf.metrics.mean() is called inside a tf.while_loop,
+    the mean cannot be accessed outside the while loop. Calling this function
+    recomputes the mean from the variables created by tf.metrics.mean(),
+    allowing the mean to be accessed outside the while loop.
+
+    Args:
+      metric_name: The string passed to the 'name' argument of
+        tf.metrics.mean().
+
+    Returns:
+      The mean of the metric.
+    """
+    metric_vars = tf.get_collection(tf.GraphKeys.METRIC_VARIABLES)
+    total_suffix = metric_name + "/total:0"
+    total_vars = [v for v in metric_vars if v.name.endswith(total_suffix)]
+    assert len(total_vars) == 1, (
+        "Found {} metric variables ending with '{}' but expected to find "
+        "exactly 1. All metric variables: {}".format(
+            len(total_vars), total_suffix, metric_vars))
+    total_var = total_vars[0]
+
+    count_suffix = metric_name + "/count:0"
+    count_vars = [v for v in metric_vars if v.name.endswith(count_suffix)]
+    assert len(count_vars) == 1, (
+        "Found {} metric variables ending with '{}' but expected to find "
+        "exactly 1. All metric variables: {}".format(
+            len(count_vars), count_suffix, metric_vars))
+    count_var = count_vars[0]
+    return total_var / count_var
+
+  def _build_model(self, params, num_steps, is_training):
     """Builds the NCF model.
 
     Args:
@@ -102,26 +143,75 @@ class NcfModelRunner(object):
       model_fn = xla.estimator_model_fn(model_fn)
 
     if is_training:
+      return self._build_train_specific_graph(
+          iterator, model_fn, params, record_files_placeholder, num_steps)
+    else:
+      return self._build_eval_specific_graph(
+          iterator, model_fn, params, record_files_placeholder, num_steps)
+
+  def _build_train_specific_graph(self, iterator, model_fn, params,
+                                  record_files_placeholder, num_train_steps):
+    """Builds the part of the model that is specific to training."""
+
+    def build():
       features, labels = iterator.get_next()
       estimator_spec = model_fn(
           features, labels, tf.estimator.ModeKeys.TRAIN, params)
       with tf.control_dependencies([estimator_spec.train_op]):
         run_model_op = self._global_step.assign_add(1)
-      return self._TrainModelProperties(
-          record_files_placeholder, iterator,
-          estimator_spec.loss, params["batch_size"], run_model_op)
+      return run_model_op, estimator_spec.loss
+
+    if self._use_while_loop:
+      def body(i):
+        run_model_op_single_step, _ = build()
+        with tf.control_dependencies([run_model_op_single_step]):
+          return i + 1
+
+      run_model_op = tf.while_loop(lambda i: i < num_train_steps, body, [0],
+                                   parallel_iterations=1)
+      loss = None
     else:
+      run_model_op, loss = build()
+
+    return self._TrainModelProperties(
+        record_files_placeholder, iterator, loss, params["batch_size"],
+        run_model_op)
+
+  def _build_eval_specific_graph(self, iterator, model_fn, params,
+                                 record_files_placeholder, num_eval_steps):
+    """Builds the part of the model that is specific to evaluation."""
+
+    def build():
       features = iterator.get_next()
       estimator_spec = model_fn(
           features, None, tf.estimator.ModeKeys.EVAL, params)
       run_model_op = tf.group(*(update_op for _, update_op in
                                 estimator_spec.eval_metric_ops.values()))
+      eval_metric_tensors = {k: tensor for (k, (tensor, _))
+                             in estimator_spec.eval_metric_ops.items()}
+      return run_model_op, estimator_spec.loss, eval_metric_tensors
+
+    if self._use_while_loop:
+      def body(i):
+        run_model_op_single_step, _, _ = build()
+        with tf.control_dependencies([run_model_op_single_step]):
+          return i + 1
+
+      run_model_op = tf.while_loop(lambda i: i < num_eval_steps, body, [0],
+                                   parallel_iterations=1)
+      loss = None
+      eval_metric_tensors = {
+          "HR": self._compute_metric_mean(rconst.HR_METRIC_NAME),
+          "NDCG": self._compute_metric_mean(rconst.NDCG_METRIC_NAME),
+      }
+    else:
+      run_model_op, loss, eval_metric_tensors = build()
+
     metric_initializer = tf.variables_initializer(
         tf.get_collection(tf.GraphKeys.METRIC_VARIABLES))
     return self._EvalModelProperties(
-        record_files_placeholder, iterator, estimator_spec.loss,
-        params["eval_batch_size"], run_model_op,
-        estimator_spec.eval_metric_ops, metric_initializer)
+        record_files_placeholder, iterator, loss, params["eval_batch_size"],
+        run_model_op, eval_metric_tensors, metric_initializer)
 
   def _train_or_eval(self, model_properties, num_steps, is_training):
     """Either trains or evaluates, depending on whether `is_training` is True.
@@ -155,16 +245,21 @@ class NcfModelRunner(object):
     self._session.run(model_properties.iterator.initializer,
                       initializer_feed_dict)
-    fetches = (model_properties.loss, model_properties.run_model_op)
+    fetches = (model_properties.run_model_op,)
+    if model_properties.loss is not None:
+      fetches += (model_properties.loss,)
     mode = "Train" if is_training else "Eval"
     start = None
-    for i in range(num_steps):
-      loss, _ = self._session.run(fetches)
+    times_to_run = 1 if self._use_while_loop else num_steps
+    for i in range(times_to_run):
+      fetches_ = self._session.run(fetches)
       if i % 100 == 0:
         if start is None:
           # Only start the timer after 100 steps so there is a warmup.
           start = time.time()
           start_step = i
-        tf.logging.info("{} Loss = {}".format(mode, loss))
+        if model_properties.loss is not None:
+          _, loss = fetches_
+          tf.logging.info("{} Loss = {}".format(mode, loss))
     end = time.time()
     if start is not None:
@@ -173,34 +268,27 @@ class NcfModelRunner(object):
     return record_dir
 
-  def train(self, num_train_steps):
-    """Trains the graph for a single cycle.
-
-    Args:
-      num_train_steps: The number of steps per cycle to train for.
-    """
+  def train(self):
+    """Trains the graph for a single cycle."""
     record_dir = self._train_or_eval(self._train_model_properties,
-                                     num_train_steps, is_training=True)
+                                     self._num_train_steps, is_training=True)
     if record_dir:
       # We delete the record_dir because each cycle, new TFRecords is generated
       # by the async process.
       tf.gfile.DeleteRecursively(record_dir)
 
-  def eval(self, num_eval_steps):
+  def eval(self):
     """Evaluates the graph on the eval data.
 
-    Args:
-      num_eval_steps: The number of steps to evaluate for.
-
     Returns:
       A dict of evaluation results.
     """
     self._session.run(self._eval_model_properties.metric_initializer)
-    self._train_or_eval(self._eval_model_properties, num_eval_steps,
+    self._train_or_eval(self._eval_model_properties, self._num_eval_steps,
                         is_training=False)
     eval_results = {
         'global_step': self._session.run(self._global_step)}
-    for key, (val, _) in self._eval_model_properties.metrics.items():
+    for key, val in self._eval_model_properties.metrics.items():
       val_ = self._session.run(val)
       tf.logging.info("{} = {}".format(key, self._session.run(val)))
       eval_results[key] = val_
...
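The core idea of the runner change above is that, with --use_while_loop, the graph itself iterates: one training (or eval) step is built inside a tf.while_loop body with parallel_iterations=1, so a single session.run() call executes the whole epoch instead of one Python round trip per step. Below is a minimal, self-contained sketch of that pattern in TF 1.x graph mode; the toy dataset, linear model, and NUM_STEPS constant are illustrative assumptions and are not part of this commit.

import tensorflow as tf

NUM_STEPS = 1000  # hypothetical number of steps per "epoch"

tf.enable_resource_variables()  # the runner also enables these when using XLA

# Toy input pipeline and model (assumptions for illustration only).
dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random_uniform([1024, 4]), tf.random_uniform([1024, 1])))
dataset = dataset.repeat().batch(32)
iterator = dataset.make_initializable_iterator()

w = tf.get_variable("w", [4, 1])
optimizer = tf.train.GradientDescentOptimizer(0.01)
global_step = tf.train.create_global_step()

def single_step():
  """Builds one training step, analogous to build() in the diff above."""
  features, labels = iterator.get_next()
  loss = tf.losses.mean_squared_error(labels, tf.matmul(features, w))
  train_op = optimizer.minimize(loss)
  with tf.control_dependencies([train_op]):
    return global_step.assign_add(1)

def body(i):
  with tf.control_dependencies([single_step()]):
    return i + 1

# parallel_iterations=1 keeps the steps strictly sequential.
run_epoch = tf.while_loop(lambda i: i < NUM_STEPS, body, [0],
                          parallel_iterations=1)

with tf.Session() as sess:
  sess.run([tf.global_variables_initializer(), iterator.initializer])
  sess.run(run_epoch)  # one session.run() executes all NUM_STEPS updates
  print("global_step =", sess.run(global_step))

Running the steps inside the graph avoids a Python/session round trip per batch, which is why the option trades per-step loss logging for throughput.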
@@ -211,7 +211,8 @@ def run_ncf(_):
         iterations=num_train_steps, params=params,
         batch_size=flags.FLAGS.batch_size, eval_batch_size=eval_batch_size)
   else:
-    runner = model_runner.NcfModelRunner(ncf_dataset, params)
+    runner = model_runner.NcfModelRunner(ncf_dataset, params, num_train_steps,
+                                         num_eval_steps, FLAGS.use_while_loop)
 
   # Create hooks that log information about the training and metric values
   train_hooks = hooks_helper.get_train_hooks(
@@ -280,11 +281,11 @@ def run_ncf(_):
           steps=num_eval_steps)
       tf.logging.info("Evaluation complete.")
     else:
-      runner.train(num_train_steps)
+      runner.train()
       tf.logging.info("Beginning evaluation.")
       mlperf_helper.ncf_print(key=mlperf_helper.TAGS.EVAL_START,
                               value=cycle_index)
-      eval_results = runner.eval(num_eval_steps)
+      eval_results = runner.eval()
       tf.logging.info("Evaluation complete.")
     hr = float(eval_results[rconst.HR_KEY])
     ndcg = float(eval_results[rconst.NDCG_KEY])
@@ -501,6 +502,21 @@ def define_ncf_flags():
           " * Reloading from checkpoints\n"
           " * Any hooks specified with --hooks\n"))
 
+  flags.DEFINE_bool(
+      name="use_while_loop", default=None, help=flags_core.help_wrap(
+          "If set, run an entire epoch in a single session.run() call using a "
+          "TensorFlow while loop. This can improve performance, but will not "
+          "print out losses throughout the epoch. Requires "
+          "--use_estimator=false."
+      ))
+
+  while_loop_message = "--use_while_loop requires --use_estimator=false"
+  @flags.multi_flags_validator(["use_while_loop", "use_estimator"],
+                               message=while_loop_message)
+  def while_loop_validator(flag_dict):
+    return (not flag_dict["use_while_loop"] or
+            not flag_dict["use_estimator"])
+
 
 if __name__ == "__main__":
   tf.logging.set_verbosity(tf.logging.INFO)
...
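The new flag is only meaningful with the low-level runner, so the validator above rejects --use_while_loop together with --use_estimator. Below is a minimal, standalone sketch of that absl-flags pattern; the flag defaults and the main() stub are assumptions for illustration, not the real ncf_main.py definitions.

from absl import app
from absl import flags

FLAGS = flags.FLAGS

# Stand-in definitions; ncf_main.py defines these with flags_core.help_wrap.
flags.DEFINE_bool("use_estimator", True, "Use tf.estimator.Estimator.")
flags.DEFINE_bool("use_while_loop", None,
                  "Run an entire epoch in a single session.run() call.")

@flags.multi_flags_validator(
    ["use_while_loop", "use_estimator"],
    message="--use_while_loop requires --use_estimator=false")
def _while_loop_validator(flag_dict):
  # Valid unless use_while_loop and use_estimator are both enabled.
  return not flag_dict["use_while_loop"] or not flag_dict["use_estimator"]

def main(_):
  print("use_while_loop =", FLAGS.use_while_loop)

if __name__ == "__main__":
  app.run(main)

With this in place, passing --use_estimator=false --use_while_loop parses cleanly, while --use_while_loop alone fails fast at startup with the validator message.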
@@ -257,6 +257,14 @@ class NcfTest(tf.test.TestCase):
     flags.FLAGS.ml_perf = True
     ncf_main.main(None)
 
+  @flagsaver.flagsaver(use_estimator=False, use_while_loop=True,
+                       **_BASE_END_TO_END_FLAGS)
+  @mock.patch.object(data_preprocessing, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
+  def test_end_to_end_while_loop(self):
+    ncf_main.main(None)
+    flags.FLAGS.ml_perf = True
+    ncf_main.main(None)
+
 
 if __name__ == "__main__":
   tf.logging.set_verbosity(tf.logging.INFO)
...
@@ -404,8 +404,10 @@ def compute_eval_loss_and_metrics(logits,  # type: tf.Tensor
   def metric_fn(top_k_tensor, ndcg_tensor, weight_tensor):
     return {
-        rconst.HR_KEY: tf.metrics.mean(top_k_tensor, weights=weight_tensor),
-        rconst.NDCG_KEY: tf.metrics.mean(ndcg_tensor, weights=weight_tensor),
+        rconst.HR_KEY: tf.metrics.mean(top_k_tensor, weights=weight_tensor,
+                                       name=rconst.HR_METRIC_NAME),
+        rconst.NDCG_KEY: tf.metrics.mean(ndcg_tensor, weights=weight_tensor,
+                                         name=rconst.NDCG_METRIC_NAME),
     }
 
   if use_tpu_spec:
...
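Naming the two tf.metrics.mean() calls is what makes _compute_metric_mean in the runner possible: the name determines the "<name>/total" and "<name>/count" variables that land in the METRIC_VARIABLES collection, and the running mean can be rebuilt from those variables outside a tf.while_loop even though the tensor returned by tf.metrics.mean() is only usable inside the loop. Below is a minimal sketch of that recovery with toy values rather than the NCF metrics; it mirrors the commit's pattern but is an illustrative assumption, not code from this change.

import tensorflow as tf

values = tf.constant([1.0, 2.0, 3.0, 4.0])

def body(i):
  # tf.metrics.mean creates "HR_METRIC/total" and "HR_METRIC/count" variables
  # (possibly under a "while/" name scope) in tf.GraphKeys.METRIC_VARIABLES.
  _, update_op = tf.metrics.mean(values[i], name="HR_METRIC")
  with tf.control_dependencies([update_op]):
    return i + 1

run_updates = tf.while_loop(lambda i: i < 4, body, [0],
                            parallel_iterations=1)

# Rebuild the mean outside the loop, as _compute_metric_mean does: match the
# variables by name suffix so any enclosing name scope does not matter.
metric_vars = tf.get_collection(tf.GraphKeys.METRIC_VARIABLES)
total = [v for v in metric_vars if v.name.endswith("HR_METRIC/total:0")][0]
count = [v for v in metric_vars if v.name.endswith("HR_METRIC/count:0")][0]
mean_outside_loop = total / count

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())  # metric variables are local
  sess.run(run_updates)
  print(sess.run(mean_outside_loop))  # 2.5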